From 37edef97d0d6a58b6238da11bf64737bb39e6d7f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 5 Feb 2024 22:28:17 +0100 Subject: [PATCH 001/296] WIP on rewriting type checker. --- futhark.cabal | 1 + src/Language/Futhark/TypeChecker.hs | 3 + src/Language/Futhark/TypeChecker/Terms2.hs | 1072 ++++++++++++++++++++ 3 files changed, 1076 insertions(+) create mode 100644 src/Language/Futhark/TypeChecker/Terms2.hs diff --git a/futhark.cabal b/futhark.cabal index 62190324f8..d9918a6554 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -414,6 +414,7 @@ library Language.Futhark.TypeChecker.Modules Language.Futhark.TypeChecker.Monad Language.Futhark.TypeChecker.Terms + Language.Futhark.TypeChecker.Terms2 Language.Futhark.TypeChecker.Terms.Loop Language.Futhark.TypeChecker.Terms.Monad Language.Futhark.TypeChecker.Terms.Pat diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index c82c0a70c8..2bded54dd9 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -34,6 +34,7 @@ import Language.Futhark.TypeChecker.Modules import Language.Futhark.TypeChecker.Monad import Language.Futhark.TypeChecker.Names import Language.Futhark.TypeChecker.Terms +import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 import Language.Futhark.TypeChecker.Types import Prelude hiding (abs, mod) @@ -695,6 +696,8 @@ checkValBind vb = do attrs' <- mapM checkAttr attrs + void $ Terms2.checkValDef (fname, maybe_tdecl, tparams, params, body, loc) + (tparams', params', maybe_tdecl', rettype, body') <- checkFunDef (fname, maybe_tdecl, tparams, params, body, loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs new file mode 100644 index 0000000000..6f12b8fcad --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -0,0 +1,1072 @@ +-- | A very WIP reimplementation of type checking of terms. +-- +-- The strategy is to split type checking into two (main) passes: +-- +-- 1) A size-agnostic pass that generates constraints (type Ct) which +-- are then solved offline to find a solution. This produces an AST +-- where most of the type annotations are just references to type +-- variables. Further, all the size-specific annotations (e.g. +-- existential sizes) just contain dummy values, such as empty lists. +-- The constraints use a type representation where all dimensions are +-- the same. However, we do try to take to store the sizes resulting +-- from explicit type ascriptions - these cannot refer to inferred +-- existentials, so it is safe to resolve them here. We don't do +-- anything with this information, however. +-- +-- 2) Pass (1) has given us a program where we know the types of +-- everything, but the sizes of nothing. Pass (2) then does +-- essentially size inference, much like the current/old type checker, +-- but of course with the massive benefit of already knowing the full +-- type of everything. This can be implemented using online constraint +-- solving (as before), or perhaps a completely syntax-driven +-- approach. +-- +-- As of this writing, only the constraint generation part of pass (1) +-- has been implemented, and it is very likely that some of the +-- constraints are actually wrong. Next step is to imlement the +-- solver. Currently all we do is dump the constraints to the +-- terminal. +-- +-- Also, no thought whatsoever has been put into quality of type +-- errors yet. However, I think an approach based on tacking source +-- information onto constraints should work well, as all constraints +-- ultimately originate from some bit of program syntax. +-- +-- Also no thought has been put into how to handle the liftedness +-- stuff. Since it does not really affect choices made during +-- inference, perhaps we can do it in a post-inference check. +module Language.Futhark.TypeChecker.Terms2 + ( checkValDef, + ) +where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor +import Data.Char (isAscii) +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe +import Data.Text qualified as T +import Debug.Trace +import Futhark.FreshNames qualified as FreshNames +import Futhark.MonadFreshNames hiding (newName) +import Futhark.Util (mapAccumLM) +import Futhark.Util.Pretty +import Language.Futhark +import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) +import Language.Futhark.TypeChecker.Monad qualified as TypeM +import Language.Futhark.TypeChecker.Types +import Language.Futhark.TypeChecker.Unify (Level, mkUsage) +import Prelude hiding (mod) + +data Inferred t + = NoneInferred + | Ascribed t + +instance Functor Inferred where + fmap _ NoneInferred = NoneInferred + fmap f (Ascribed t) = Ascribed (f t) + +data ValBinding + = BoundV [TypeParam] StructType + | OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType) + | EqualityF + deriving (Show) + +type Type = TypeBase () NoUniqueness + +toType :: TypeBase d u -> Type +toType = bimap (const ()) (const NoUniqueness) + +expType :: Exp -> Type +expType = toType . typeOf + +data Ct + = CtEq Type Type + | CtOneOf Type [PrimType] + | CtHasConstr Type Name [Type] + | CtHasField Type Name Type + deriving (Show) + +instance Pretty Ct where + pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 + pretty (CtOneOf t1 ts) = pretty t1 <+> "∈" <+> pretty ts + pretty (CtHasConstr t1 k ts) = + pretty t1 <+> "~" <+> "... | " <+> hsep ("#" <> pretty k : map pretty ts) <+> " | ..." + pretty (CtHasField t1 k t) = + pretty t1 <+> "~" <+> braces ("..." <+> pretty k <> ":" <+> pretty t <+> "...") + +type Constraints = [Ct] + +-- | The substitution (or other information) known about a type +-- variable. +data TyVarSub + = -- | No substitution known yet; can be substituted with anything. + TyVarFree + | -- | This substitution has been found. + TyVarSub Type + deriving (Show) + +instance Pretty TyVarSub where + pretty TyVarFree = "free" + pretty (TyVarSub t) = "=" <> pretty t + +type TyVar = VName + +-- | If a VName is not in this map, it is assumed to be rigid. +type TyVars = M.Map TyVar TyVarSub + +data TermScope = TermScope + { scopeVtable :: M.Map VName ValBinding, + scopeTypeTable :: M.Map VName TypeBinding, + scopeModTable :: M.Map VName Mod + } + deriving (Show) + +instance Semigroup TermScope where + TermScope vt1 tt1 mt1 <> TermScope vt2 tt2 mt2 = + TermScope (vt2 `M.union` vt1) (tt2 `M.union` tt1) (mt1 `M.union` mt2) + +-- | Type checking happens with access to this environment. The +-- 'TermScope' will be extended during type-checking as bindings come into +-- scope. +data TermEnv = TermEnv + { termScope :: TermScope, + termLevel :: Level, + termOuterEnv :: Env, + termImportName :: ImportName + } + +-- | The state is a set of constraints and a counter for generating +-- type names. This is distinct from the usual counter we use for +-- generating unique names, as these will be user-visible. +data TermState = TermState + { termConstraints :: Constraints, + termTyVars :: TyVars, + termCounter :: !Int, + termWarnings :: Warnings, + termNameSource :: VNameSource + } + +newtype TermM a + = TermM + ( ReaderT + TermEnv + (StateT TermState (Except (Warnings, TypeError))) + a + ) + deriving + ( Monad, + Functor, + Applicative, + MonadReader TermEnv, + MonadState TermState + ) + +envToTermScope :: Env -> TermScope +envToTermScope env = + TermScope + { scopeVtable = vtable, + scopeTypeTable = envTypeTable env, + scopeModTable = envModTable env + } + where + vtable = M.map valBinding $ envVtable env + valBinding (TypeM.BoundV tps v) = BoundV tps v + +initialTermScope :: TermScope +initialTermScope = + TermScope + { scopeVtable = initialVtable, + scopeTypeTable = mempty, + scopeModTable = mempty + } + where + initialVtable = M.fromList $ mapMaybe addIntrinsicF $ M.toList intrinsics + + prim = Scalar . Prim + arrow x y = Scalar $ Arrow mempty Unnamed Observe x y + + addIntrinsicF (name, IntrinsicMonoFun pts t) = + Just (name, BoundV [] $ arrow pts' $ RetType [] $ prim t) + where + pts' = case pts of + [pt] -> prim pt + _ -> Scalar $ tupleRecord $ map prim pts + addIntrinsicF (name, IntrinsicOverloadedFun ts pts rts) = + Just (name, OverloadedF ts pts rts) + addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = + Just + ( name, + BoundV tvs $ foldFunType pts rt + ) + addIntrinsicF (name, IntrinsicEquality) = + Just (name, EqualityF) + addIntrinsicF _ = Nothing + +runTermM :: TermM a -> TypeM a +runTermM (TermM m) = do + initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv + name <- askImportName + outer_env <- askEnv + src <- gets stateNameSource + let initial_env = + TermEnv + { termScope = initial_scope, + termLevel = 0, + termImportName = name, + termOuterEnv = outer_env + } + initial_state = + TermState + { termConstraints = mempty, + termTyVars = mempty, + termWarnings = mempty, + termNameSource = src, + termCounter = 0 + } + case runExcept (runStateT (runReaderT m initial_env) initial_state) of + Left (ws, e) -> do + warnings ws + throwError e + Right (a, TermState {termNameSource, termWarnings}) -> do + warnings termWarnings + modify $ \s -> s {stateNameSource = termNameSource} + pure a + +incLevel :: TermM a -> TermM a +incLevel = local $ \env -> env {termLevel = termLevel env + 1} + +incCounter :: TermM Int +incCounter = do + s <- get + put s {termCounter = termCounter s + 1} + pure $ termCounter s + +tyVarType :: (Monoid u) => TyVar -> TypeBase dim u +tyVarType v = Scalar $ TypeVar mempty (qualName v) [] + +newTyVar :: a -> Name -> TermM TyVar +newTyVar loc desc = do + i <- incCounter + v <- newID $ mkTypeVarName desc i + modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} + pure v + +newType :: (Monoid u) => a -> Name -> TermM (TypeBase dim u) +newType loc desc = tyVarType <$> newTyVar loc desc + +addCt :: Ct -> TermM () +addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} + +ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () +ctEq t1 t2 = addCt $ CtEq (toType t1) (toType t2) + +ctHasConstr :: TypeBase d1 u1 -> Name -> [TypeBase d2 u2] -> TermM () +ctHasConstr t1 k t2 = addCt $ CtHasConstr (toType t1) k (map toType t2) + +ctHasField :: TypeBase d1 u1 -> Name -> TypeBase d2 u2 -> TermM () +ctHasField t1 k t = addCt $ CtHasField (toType t1) k (toType t) + +ctOneOf :: TypeBase d1 u1 -> [PrimType] -> TermM () +ctOneOf t ts = addCt $ CtOneOf (toType t) ts + +localScope :: (TermScope -> TermScope) -> TermM a -> TermM a +localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} + +withEnv :: TermEnv -> Env -> TermEnv +withEnv tenv env = tenv {termScope = termScope tenv <> envToTermScope env} + +lookupQualNameEnv :: QualName VName -> TermM TermScope +lookupQualNameEnv (QualName [q] _) + | baseTag q <= maxIntrinsicTag = asks termScope -- Magical intrinsic module. +lookupQualNameEnv qn@(QualName quals _) = do + scope <- asks termScope + descend scope quals + where + descend scope [] = pure scope + descend scope (q : qs) + | Just (ModEnv q_scope) <- M.lookup q $ scopeModTable scope = + descend (envToTermScope q_scope) qs + | otherwise = + error $ "lookupQualNameEnv " <> show qn + +instance MonadError TypeError TermM where + throwError e = TermM $ do + ws <- gets termWarnings + throwError (ws, e) + + catchError (TermM m) f = + TermM $ m `catchError` f' + where + f' (_, e) = let TermM m' = f e in m' + +instance MonadTypeChecker TermM where + checkExpForSize = checkExp + + warnings ws = modify $ \s -> s {termWarnings = termWarnings s <> ws} + + warn loc problem = warnings $ singleWarning (locOf loc) problem + + newName v = do + s <- get + let (v', src') = FreshNames.newName (termNameSource s) v + put $ s {termNameSource = src'} + pure v' + + newID s = newName $ VName s 0 + + newTypeName name = do + i <- incCounter + newID $ mkTypeVarName name i + + bindVal v (TypeM.BoundV tps t) = localScope $ \scope -> + scope {scopeVtable = M.insert v (BoundV tps t) $ scopeVtable scope} + + lookupType qn = do + outer_env <- asks termOuterEnv + scope <- lookupQualNameEnv qn + case M.lookup (qualLeaf qn) $ scopeTypeTable scope of + Nothing -> error $ "lookupType: " <> show qn + Just (TypeAbbr l ps (RetType dims def)) -> + pure + ( ps, + RetType dims $ qualifyTypeVars outer_env (map typeParamName ps) (qualQuals qn) def, + l + ) + + typeError loc notes s = + throwError $ TypeError (locOf loc) notes s + +--- All the general machinery goes above. + +require :: T.Text -> [PrimType] -> Exp -> TermM Exp +require why ts e = do + ctOneOf (typeOf e) ts + pure e + +-- | Create a new type name and insert it (unconstrained) in the set +-- of type variables. +instTypeParam :: + (Monoid as) => + QualName VName -> + SrcLoc -> + TypeParam -> + TermM (VName, Subst (RetTypeBase dim as)) +instTypeParam qn loc tparam = do + i <- incCounter + let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) + v <- newID $ mkTypeVarName name i + case tparam of + TypeParamType {} -> do + modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} + pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) + TypeParamDim {} -> + pure (v, ExpSubst $ sizeFromName (qualName v) loc) + +-- | Instantiate a type scheme with fresh type variables for its type +-- parameters. Returns the names of the fresh type variables, the +-- instance list, and the instantiated type. +instTypeScheme :: + QualName VName -> + SrcLoc -> + [TypeParam] -> + StructType -> + TermM ([VName], StructType) +instTypeScheme qn loc tparams t = do + (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do + case tparam of + TypeParamType x _ _ -> do + i <- incCounter + let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) + v <- newID $ mkTypeVarName name i + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) [])) + TypeParamDim {} -> + pure Nothing + let t' = applySubst (`lookup` substs) t + pure (names, t') + +lookupMod :: QualName VName -> TermM Mod +lookupMod qn@(QualName _ name) = do + scope <- lookupQualNameEnv qn + case M.lookup name $ scopeModTable scope of + Nothing -> error $ "lookupMod: " <> show qn + Just m -> pure m + +lookupVar :: SrcLoc -> QualName VName -> TermM StructType +lookupVar loc qn@(QualName qs name) = do + scope <- lookupQualNameEnv qn + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams t) -> do + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newType loc "t" + pure $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ + Scalar $ + Arrow mempty Unnamed Observe argtype $ + RetType [] $ + Scalar $ + Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newType loc "t" + ctOneOf argtype ts + let (pts', rt') = instOverloaded (argtype :: StructType) pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + where + instOverloaded argtype pts rt = + ( map (maybe argtype (Scalar . Prim)) pts, + maybe argtype (Scalar . Prim) rt + ) + +bind :: + [Ident StructType] -> + TermM a -> + TermM a +bind idents = localScope (`bindVars` idents) + where + bindVars = foldl bindVar + + bindVar scope (Ident name (Info tp) _) = + scope + { scopeVtable = M.insert name (BoundV [] tp) $ scopeVtable scope + } + +-- All this complexity is just so we can handle un-suffixed numeric +-- literals in patterns. +patLitMkType :: PatLit -> SrcLoc -> TermM ParamType +patLitMkType (PatLitInt _) loc = do + t <- newType loc "t" + ctOneOf t anyNumberType + pure t +patLitMkType (PatLitFloat _) loc = do + t <- newType loc "t" + ctOneOf t anyFloatType + pure t +patLitMkType (PatLitPrim v) _ = + pure $ Scalar $ Prim $ primValueType v + +checkPat' :: + PatBase NoInfo VName ParamType -> + Inferred ParamType -> + TermM (Pat ParamType) +checkPat' (PatParens p loc) t = + PatParens <$> checkPat' p t <*> pure loc +checkPat' (PatAttr attr p loc) t = + PatAttr <$> checkAttr attr <*> checkPat' p t <*> pure loc +checkPat' (Id name NoInfo loc) (Ascribed t) = + pure $ Id name (Info t) loc +checkPat' (Id name NoInfo loc) NoneInferred = do + t <- newType loc "t" + pure $ Id name (Info t) loc +checkPat' (Wildcard _ loc) (Ascribed t) = + pure $ Wildcard (Info t) loc +checkPat' (Wildcard NoInfo loc) NoneInferred = do + t <- newType loc "t" + pure $ Wildcard (Info t) loc +checkPat' p@(TuplePat ps loc) (Ascribed t) + | Just ts <- isTupleRecord t, + length ts == length ps = + TuplePat + <$> zipWithM checkPat' ps (map Ascribed ts) + <*> pure loc + | otherwise = do + ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") + ctEq (Scalar (tupleRecord ps_t)) t + checkPat' p $ Ascribed $ toParam Observe $ Scalar $ tupleRecord ps_t +checkPat' (TuplePat ps loc) NoneInferred = + TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc +checkPat' p@(RecordPat p_fs loc) (Ascribed t) + | Scalar (Record t_fs) <- t, + L.sort (map fst p_fs) == L.sort (M.keys t_fs) = + RecordPat . M.toList <$> check t_fs <*> pure loc + | otherwise = do + p_fs' <- traverse (const $ newType loc "t") $ M.fromList p_fs + ctEq (Scalar (Record p_fs') :: ParamType) t + checkPat' p $ Ascribed $ toParam Observe $ Scalar (Record p_fs') + where + check t_fs = + traverse (uncurry checkPat') $ + M.intersectionWith (,) (M.fromList p_fs) (fmap Ascribed t_fs) +checkPat' (RecordPat fs loc) NoneInferred = + RecordPat . M.toList + <$> traverse (`checkPat'` NoneInferred) (M.fromList fs) + <*> pure loc +checkPat' (PatAscription p t loc) maybe_outer_t = do + (t', _, RetType dims st, _) <- checkTypeExp t + + case maybe_outer_t of + Ascribed outer_t -> do + ctEq st outer_t + PatAscription + <$> checkPat' p (Ascribed (resToParam st)) + <*> pure t' + <*> pure loc + NoneInferred -> + PatAscription + <$> checkPat' p (Ascribed (resToParam st)) + <*> pure t' + <*> pure loc +checkPat' (PatLit l NoInfo loc) (Ascribed t) = do + t' <- patLitMkType l loc + addCt $ CtEq (toType t') (toType t) + pure $ PatLit l (Info t') loc +checkPat' (PatLit l NoInfo loc) NoneInferred = do + t' <- patLitMkType l loc + pure $ PatLit l (Info t') loc +checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) + | Just ts <- M.lookup n cs = do + when (length ps /= length ts) $ + typeError loc mempty $ + "Pattern #" + <> pretty n + <> " expects" + <+> pretty (length ps) + <+> "constructor arguments, but type provides" + <+> pretty (length ts) + <+> "arguments." + ps' <- zipWithM checkPat' ps $ map Ascribed ts + pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc +checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do + t' <- newType loc "t" + ps' <- forM ps $ \p -> do + p_t <- newType (srclocOf p) "t" + checkPat' p $ Ascribed p_t + ctHasConstr (t' :: ParamType) n $ map patternStructType ps' + pure $ PatConstr n (Info t) ps' loc +checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do + ps' <- mapM (`checkPat'` NoneInferred) ps + t <- newType loc "t" + ctHasConstr t n $ map patternStructType ps' + pure $ PatConstr n (Info t) ps' loc + +checkPat :: + PatBase NoInfo VName (TypeBase Size u) -> + Inferred StructType -> + (Pat ParamType -> TermM a) -> + TermM a +checkPat p t m = + m =<< checkPat' (fmap (toParam Observe) p) (fmap (toParam Observe) t) + +-- | Bind @let@-bound sizes. This is usually followed by 'bindletPat' +-- immediately afterwards. +bindSizes :: [SizeBinder VName] -> TermM a -> TermM a +bindSizes [] m = m -- Minor optimisation. +bindSizes sizes m = bind (map sizeWithType sizes) m + where + sizeWithType size = + Ident (sizeName size) (Info (Scalar (Prim (Signed Int64)))) (srclocOf size) + +bindLetPat :: + PatBase NoInfo VName (TypeBase Size u) -> + StructType -> + (Pat ParamType -> TermM a) -> + TermM a +bindLetPat p t m = do + checkPat p (Ascribed t) $ \p' -> + bind (patIdents (fmap toStruct p')) $ + m p' + +typeParamIdent :: TypeParam -> Maybe (Ident StructType) +typeParamIdent (TypeParamDim v loc) = + Just $ Ident v (Info $ Scalar $ Prim $ Signed Int64) loc +typeParamIdent _ = Nothing + +bindTypes :: + [(VName, TypeBinding)] -> + TermM a -> + TermM a +bindTypes tbinds = localScope extend + where + extend scope = + scope + { scopeTypeTable = M.fromList tbinds <> scopeTypeTable scope + } + +bindTypeParams :: [TypeParam] -> TermM a -> TermM a +bindTypeParams tparams = + bind (mapMaybe typeParamIdent tparams) + . bindTypes (mapMaybe typeParamType tparams) + where + typeParamType (TypeParamType l v _) = + Just (v, TypeAbbr l [] $ RetType [] $ Scalar (TypeVar mempty (qualName v) [])) + typeParamType TypeParamDim {} = + Nothing + +bindParams :: + [TypeParam] -> + [PatBase NoInfo VName ParamType] -> + ([Pat ParamType] -> TermM a) -> + TermM a +bindParams tps orig_ps m = bindTypeParams tps $ do + let descend ps' (p : ps) = + checkPat p NoneInferred $ \p' -> + bind (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps + descend ps' [] = m $ reverse ps' + + incLevel $ descend [] orig_ps + +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM TyVar +checkApply loc (fname, _) ftype arg = do + a <- newType loc "a" + b <- newTyVar loc "b" + ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) + ctEq a (expType arg) + pure b + +checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] +checkSlice = mapM checkDimIndex + where + checkDimIndex (DimFix i) = + DimFix <$> check i + checkDimIndex (DimSlice i j s) = + DimSlice <$> traverse check i <*> traverse check j <*> traverse check s + + check = require "use as index" anySignedType <=< checkExp + +isSlice :: DimIndexBase f vn -> Bool +isSlice DimSlice {} = True +isSlice DimFix {} = False + +-- Add constraints saying that the first type has a (potentially +-- nested) field containing the second type. +mustHaveFields :: + SrcLoc -> + TypeBase d1 u1 -> + [Name] -> + TypeBase d2 u2 -> + TermM () +mustHaveFields loc t [] ve_t = ctEq t ve_t +mustHaveFields loc t (f : fs) ve_t = do + f_t :: Type <- newType loc "ft" + ctHasField t f f_t + mustHaveFields loc f_t fs ve_t + +checkCase :: + StructType -> + CaseBase NoInfo VName -> + TermM (CaseBase Info VName, StructType) +checkCase mt (CasePat p e loc) = + bindLetPat p mt $ \p' -> do + e' <- checkExp e + pure (CasePat (fmap toStruct p') e' loc, typeOf e') + +checkCases :: + StructType -> + NE.NonEmpty (CaseBase NoInfo VName) -> + TermM (NE.NonEmpty (CaseBase Info VName), StructType) +checkCases mt rest_cs = + case NE.uncons rest_cs of + (c, Nothing) -> do + (c', t) <- checkCase mt c + pure (NE.singleton c', t) + (c, Just cs) -> do + (c', c_t) <- checkCase mt c + (cs', cs_t) <- checkCases mt cs + ctEq c_t cs_t + pure (NE.cons c' cs', c_t) + +-- | An unmatched pattern. Used in in the generation of +-- unmatched pattern warnings by the type checker. +data Unmatched p + = UnmatchedNum p [PatLit] + | UnmatchedBool p + | UnmatchedConstr p + | Unmatched p + deriving (Functor, Show) + +instance Pretty (Unmatched (Pat StructType)) where + pretty um = case um of + (UnmatchedNum p nums) -> pretty' p <+> "where p is not one of" <+> pretty nums + (UnmatchedBool p) -> pretty' p + (UnmatchedConstr p) -> pretty' p + (Unmatched p) -> pretty' p + where + pretty' (PatAscription p t _) = pretty p <> ":" <+> pretty t + pretty' (PatParens p _) = parens $ pretty' p + pretty' (PatAttr _ p _) = parens $ pretty' p + pretty' (Id v _ _) = prettyName v + pretty' (TuplePat pats _) = parens $ commasep $ map pretty' pats + pretty' (RecordPat fs _) = braces $ commasep $ map ppField fs + where + ppField (name, t) = pretty (nameToString name) <> equals <> pretty' t + pretty' Wildcard {} = "_" + pretty' (PatLit e _ _) = pretty e + pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) + +checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) +-- +checkExp (Var qn _ loc) = do + t <- lookupVar loc qn + pure $ Var qn (Info t) loc +checkExp (OpSection op _ loc) = do + ftype <- lookupVar loc op + pure $ OpSection op (Info ftype) loc +checkExp (Negate arg loc) = do + arg' <- require "numeric negation" anyNumberType =<< checkExp arg + pure $ Negate arg' loc +checkExp (Not arg loc) = do + arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg + pure $ Not arg' loc +checkExp (Hole NoInfo loc) = + Hole <$> (Info <$> newType loc "hole") <*> pure loc +checkExp (Parens e loc) = + Parens <$> checkExp e <*> pure loc +checkExp (TupLit es loc) = + TupLit <$> mapM checkExp es <*> pure loc +checkExp (QualParens (modname, modnameloc) e loc) = do + mod <- lookupMod modname + case mod of + ModEnv env -> local (`withEnv` env) $ do + e' <- checkExp e + pure $ QualParens (modname, modnameloc) e' loc + ModFun {} -> + typeError loc mempty . withIndexLink "module-is-parametric" $ + "Module" <+> pretty modname <+> " is a parametric module." +-- +checkExp (IntLit x NoInfo loc) = do + t <- newType loc "num" + ctOneOf t anyNumberType + pure $ IntLit x (Info t) loc +checkExp (FloatLit x NoInfo loc) = do + t <- newType loc "float" + ctOneOf t anyFloatType + pure $ FloatLit x (Info t) loc +checkExp (Literal v loc) = + pure $ Literal v loc +checkExp (StringLit vs loc) = + pure $ StringLit vs loc +checkExp (ArrayLit es _ loc) = do + -- TODO: this will produce an enormous number of constraints and + -- type variables for pathologically large arrays with + -- type-unsuffixed integers. Add some special case that handles that + -- more efficiently. + et <- newType loc "et" + es' <- forM es $ \e -> do + e' <- checkExp e + ctEq (typeOf e') et + pure e' + let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et + pure $ ArrayLit es' (Info arr_t) loc +checkExp (RecordLit fs loc) = + RecordLit <$> evalStateT (mapM checkField fs) mempty <*> pure loc + where + checkField (RecordFieldExplicit f e rloc) = do + errIfAlreadySet f rloc + modify $ M.insert f rloc + RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc + checkField (RecordFieldImplicit name NoInfo rloc) = do + errIfAlreadySet (baseName name) rloc + t <- lift $ lookupVar rloc $ qualName name + modify $ M.insert (baseName name) rloc + pure $ RecordFieldImplicit name (Info t) rloc + + errIfAlreadySet f rloc = do + maybe_sloc <- gets $ M.lookup f + case maybe_sloc of + Just sloc -> + lift . typeError rloc mempty $ + "Field" + <+> dquotes (pretty f) + <+> "previously defined at" + <+> pretty (locStrRel rloc sloc) + <> "." + Nothing -> pure () + +-- +checkExp (Attr info e loc) = + Attr <$> checkAttr info <*> checkExp e <*> pure loc +checkExp (Assert e1 e2 NoInfo loc) = do + e1' <- require "being asserted" [Bool] =<< checkExp e1 + e2' <- checkExp e2 + pure $ Assert e1' e2' (Info (prettyText e1)) loc +-- +checkExp (Constr name es NoInfo loc) = do + t <- newType loc "t" + es' <- mapM checkExp es + ctHasConstr t name $ map typeOf es' + pure $ Constr name es' (Info t) loc +-- +checkExp (AppExp (Apply fe args loc) NoInfo) = do + fe' <- checkExp fe + ((_, rt), args') <- mapAccumLM onArg (0, typeOf fe') args + + pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt [] + where + fname = + case fe of + Var v _ _ -> Just v + _ -> Nothing + + onArg (i, f_t) (_, arg) = do + arg' <- checkExp arg + rt <- checkApply loc (fname, i) (toType f_t) arg' + pure + ( (i + 1, tyVarType rt), + (Info Nothing, arg') + ) +-- +checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do + ftype <- lookupVar oploc op + e1' <- checkExp e1 + e2' <- checkExp e2 + + rt1 <- checkApply loc (Just op, 0) (toType ftype) e1' + rt2 <- checkApply loc (Just op, 1) (tyVarType rt1) e2' + + pure $ + AppExp + (BinOp (op, oploc) (Info ftype) (e1', Info Nothing) (e2', Info Nothing) loc) + (Info (AppRes (tyVarType rt2) [])) +-- +checkExp (OpSectionLeft op _ e _ _ loc) = do + optype <- lookupVar loc op + e' <- checkExp e + rt <- checkApply loc (Just op, 0) (toType optype) e' + pure $ + OpSectionLeft + op + (Info optype) + e' + -- Dummy types. + ( Info (Unnamed, Scalar $ Prim Bool, Nothing), + Info (Unnamed, Scalar $ Prim Bool) + ) + (Info (RetType [] (tyVarType rt)), Info []) + loc +-- +checkExp (OpSectionRight op _ e _ NoInfo loc) = do + optype <- lookupVar loc op + e' <- checkExp e + t1 <- newType loc "t" + rt <- newType loc "rt" + ctEq optype $ foldFunType [t1, toParam Observe $ typeOf e'] $ RetType [] rt + pure $ + OpSectionRight + op + (Info optype) + e' + -- Dummy types. + ( Info (Unnamed, Scalar $ Prim Bool), + Info (Unnamed, Scalar $ Prim Bool, Nothing) + ) + (Info $ RetType [] rt) + loc +-- +checkExp (ProjectSection fields NoInfo loc) = do + a <- newType loc "a" + b <- newType loc "b" + mustHaveFields loc a fields b + let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b + pure $ ProjectSection fields (Info ft) loc +-- +checkExp (Lambda params body rettype NoInfo loc) = do + bindParams [] params $ \params' -> do + body' <- checkExp body + rettype_te' <- case rettype of + Just rettype_te -> do + (rettype_te', _, RetType _ st, _) <- checkTypeExp rettype_te + ctEq (typeOf body') st + pure $ Just rettype_te' + Nothing -> pure Nothing + let ret = RetType [] $ toRes Nonunique $ typeOf body' + pure $ Lambda params' body' rettype_te' (Info ret) loc +-- +checkExp (AppExp (LetPat sizes pat e body loc) _) = do + e' <- checkExp e + + bindSizes sizes . incLevel . bindLetPat pat (typeOf e') $ \pat' -> do + body' <- incLevel $ checkExp body + pure $ + AppExp + (LetPat sizes (fmap toStruct pat') e' body' loc) + (Info $ AppRes (typeOf body') []) +-- +checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do + (tparams', params', maybe_retdecl', rettype, e') <- + bindParams tparams params $ \params' -> do + e' <- checkExp e + let ret = RetType [] $ toRes Nonunique $ typeOf e' + pure (tparams, params', undefined, ret, e') + + let entry = BoundV tparams' $ funType params' rettype + bindF scope = + scope + { scopeVtable = M.insert name entry $ scopeVtable scope + } + body' <- localScope bindF $ checkExp body + + pure $ + AppExp + ( LetFun + name + (tparams', params', maybe_retdecl', Info rettype, e') + body' + loc + ) + (Info $ AppRes (typeOf body') []) +-- +checkExp (AppExp (Range start maybe_step end loc) _) = do + start' <- checkExp' start + maybe_step' <- traverse checkExp' maybe_step + end' <- traverse checkExp' end + range_t <- newType loc "range" + ctEq range_t $ arrayOf (Shape [()]) (toType (typeOf start')) + pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] + where + checkExp' = require "use in range expression" anyIntType <=< checkExp +-- +checkExp (Project k e NoInfo loc) = do + e' <- checkExp e + kt <- newType loc "t" + ctHasField (typeOf e') k kt + pure $ Project k e' (Info kt) loc +-- +checkExp (RecordUpdate src fields ve NoInfo loc) = do + src' <- checkExp src + ve' <- checkExp ve + mustHaveFields loc (typeOf src') fields (typeOf ve') + pure $ RecordUpdate src' fields ve' (Info (typeOf src')) loc +-- +checkExp (IndexSection slice NoInfo loc) = do + slice' <- checkSlice slice + index_arg_t <- newType loc "index" + index_elem_t <- newType loc "index_elem" + index_res_t <- newType loc "index_res" + let num_slices = length $ filter isSlice slice + ctEq index_arg_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t + ctEq index_res_t $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ RetType [] index_res_t + pure $ IndexSection slice' (Info ft) loc +-- +checkExp (AppExp (Index e slice loc) _) = do + e' <- checkExp e + slice' <- checkSlice slice + index_t <- newType loc "index" + index_elem_t <- newType loc "index_elem" + let num_slices = length $ filter isSlice slice + ctEq index_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t + ctEq (typeOf e') $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) +-- +checkExp (Update src slice ve loc) = do + src' <- checkExp src + slice' <- checkSlice slice + ve' <- checkExp ve + let num_slices = length $ filter isSlice slice + update_elem_t <- newType loc "update_elem" + ctEq (typeOf src') $ arrayOf (Shape (replicate (length slice) ())) update_elem_t + ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + pure $ Update src' slice' ve' loc +-- +checkExp (AppExp (LetWith dest src slice ve body loc) _) = do + src_t <- lookupVar (srclocOf src) $ qualName $ identName src + let src' = src {identType = Info src_t} + dest' = dest {identType = Info src_t} + slice' <- checkSlice slice + ve' <- checkExp ve + let num_slices = length $ filter isSlice slice + update_elem_t <- newType loc "update_elem" + ctEq src_t $ arrayOf (Shape (replicate (length slice) ())) update_elem_t + ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + bind [dest'] $ do + body' <- checkExp body + pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) +-- +checkExp (AppExp (If e1 e2 e3 loc) _) = do + e1' <- checkExp e1 + e2' <- checkExp e2 + e3' <- checkExp e3 + + ctEq (typeOf e1') (Scalar (Prim Bool) :: Type) + ctEq (typeOf e2') (typeOf e3') + + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e1') []) +-- +checkExp (AppExp (Match e cs loc) _) = do + e' <- checkExp e + (cs', t) <- checkCases (typeOf e') cs + pure $ AppExp (Match e' cs' loc) (Info $ AppRes t []) +-- +checkExp (AppExp (Loop _ pat arg form body loc) _) = do + arg' <- checkExp arg + bindLetPat pat (typeOf arg') $ \pat' -> do + (form', body') <- + case form of + For (Ident i _ iloc) bound -> do + bound' <- require "loop bound" anyIntType =<< checkExp bound + let i' = Ident i (Info (typeOf bound')) iloc + bind [i'] $ do + body' <- checkExp body + ctEq (typeOf arg') (typeOf body') + pure (For i' bound', body') + While cond -> do + cond' <- checkExp cond + body' <- checkExp body + ctEq (typeOf arg') (typeOf body') + pure (While cond', body') + ForIn elemp arr -> do + arr' <- checkExp arr + elem_t <- newType elemp "elem" + ctEq (typeOf arr') $ arrayOf (Shape [()]) (toType elem_t) + bindLetPat elemp elem_t $ \elemp' -> do + body' <- checkExp body + pure (ForIn (toStruct <$> elemp') arr', body') + pure $ + AppExp + (Loop [] pat' arg' form' body' loc) + (Info (AppRes (patternStructType pat') [])) +-- +checkExp (Ascript e te loc) = do + e' <- checkExp e + (te', _, RetType _ st, _) <- checkTypeExp te + ctEq (typeOf e') st + pure $ Ascript e' te' loc +checkExp (Coerce e te NoInfo loc) = do + e' <- checkExp e + (te', _, RetType _ st, _) <- checkTypeExp te + ctEq (typeOf e') st + pure $ Coerce e' te' (Info (toStruct st)) loc + +-- +-- + +checkValDef :: + ( VName, + Maybe (TypeExp NoInfo VName), + [TypeParam], + [PatBase NoInfo VName ParamType], + ExpBase NoInfo VName, + SrcLoc + ) -> + TypeM + ( [TypeParam], + [Pat ParamType], + Maybe (TypeExp Info VName), + ResRetType, + Exp + ) +checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do + bindParams tparams params $ \params' -> do + body' <- checkExp body + cts <- gets termConstraints + tyvars <- gets termTyVars + traceM $ + unlines + [ "function " <> prettyNameString fname, + "constraints:", + prettyString cts, + "tyvars:", + prettyString $ map (first prettyNameString) $ M.toList tyvars + ] + pure (undefined, params', undefined, undefined, body') From ab529bb4031019053101eef8cf76cc628e529437 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 6 Feb 2024 11:26:11 +0100 Subject: [PATCH 002/296] Do not think of overloading as constraints. --- src/Language/Futhark/TypeChecker/Terms2.hs | 97 +++++++++++----------- 1 file changed, 50 insertions(+), 47 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6f12b8fcad..52106ad28d 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -256,24 +256,33 @@ newTyVar loc desc = do modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} pure v -newType :: (Monoid u) => a -> Name -> TermM (TypeBase dim u) +newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) newType loc desc = tyVarType <$> newTyVar loc desc +newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> TypeBase dim u -> TermM (TypeBase dim u) +newTypeWithField loc desc k t = do + rt <- newType loc desc + addCt $ CtHasField (toType rt) k (toType t) + pure rt + +newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [Type] -> TermM (TypeBase dim u) +newTypeWithConstr loc desc k ts = do + t <- newType loc desc + addCt $ CtHasConstr (toType t) k ts + pure t + +newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase dim u) +newTypeOverloaded loc name pts = do + t <- newType loc name + addCt $ CtOneOf (toType t) pts + pure t + addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () ctEq t1 t2 = addCt $ CtEq (toType t1) (toType t2) -ctHasConstr :: TypeBase d1 u1 -> Name -> [TypeBase d2 u2] -> TermM () -ctHasConstr t1 k t2 = addCt $ CtHasConstr (toType t1) k (map toType t2) - -ctHasField :: TypeBase d1 u1 -> Name -> TypeBase d2 u2 -> TermM () -ctHasField t1 k t = addCt $ CtHasField (toType t1) k (toType t) - -ctOneOf :: TypeBase d1 u1 -> [PrimType] -> TermM () -ctOneOf t ts = addCt $ CtOneOf (toType t) ts - localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -344,8 +353,9 @@ instance MonadTypeChecker TermM where --- All the general machinery goes above. require :: T.Text -> [PrimType] -> Exp -> TermM Exp -require why ts e = do - ctOneOf (typeOf e) ts +require why pts e = do + t :: Type <- newTypeOverloaded (srclocOf e) "t" pts + ctEq t $ expType e pure e -- | Create a new type name and insert it (unconstrained) in the set @@ -419,8 +429,7 @@ lookupVar loc qn@(QualName qs name) = do Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do - argtype <- newType loc "t" - ctOneOf argtype ts + argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded (argtype :: StructType) pts rt pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where @@ -445,14 +454,10 @@ bind idents = localScope (`bindVars` idents) -- All this complexity is just so we can handle un-suffixed numeric -- literals in patterns. patLitMkType :: PatLit -> SrcLoc -> TermM ParamType -patLitMkType (PatLitInt _) loc = do - t <- newType loc "t" - ctOneOf t anyNumberType - pure t -patLitMkType (PatLitFloat _) loc = do - t <- newType loc "t" - ctOneOf t anyFloatType - pure t +patLitMkType (PatLitInt _) loc = + newTypeOverloaded loc "t" anyNumberType +patLitMkType (PatLitFloat _) loc = + newTypeOverloaded loc "t" anyFloatType patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v @@ -538,16 +543,15 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) ps' <- zipWithM checkPat' ps $ map Ascribed ts pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do - t' <- newType loc "t" ps' <- forM ps $ \p -> do p_t <- newType (srclocOf p) "t" checkPat' p $ Ascribed p_t - ctHasConstr (t' :: ParamType) n $ map patternStructType ps' - pure $ PatConstr n (Info t) ps' loc + t' <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' + ctEq t' t + pure $ PatConstr n (Info t') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps - t <- newType loc "t" - ctHasConstr t n $ map patternStructType ps' + t <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' pure $ PatConstr n (Info t) ps' loc checkPat :: @@ -640,17 +644,18 @@ isSlice DimFix {} = False -- Add constraints saying that the first type has a (potentially -- nested) field containing the second type. -mustHaveFields :: - SrcLoc -> - TypeBase d1 u1 -> - [Name] -> - TypeBase d2 u2 -> - TermM () -mustHaveFields loc t [] ve_t = ctEq t ve_t +mustHaveFields :: SrcLoc -> Type -> [Name] -> Type -> TermM () +mustHaveFields _ t [] ve_t = + -- This case is probably never reached. + ctEq t ve_t +mustHaveFields loc t [f] ve_t = do + rt :: Type <- newTypeWithField loc "ft" f ve_t + ctEq t rt mustHaveFields loc t (f : fs) ve_t = do - f_t :: Type <- newType loc "ft" - ctHasField t f f_t - mustHaveFields loc f_t fs ve_t + ft :: Type <- newType loc "ft" + rt <- newTypeWithField loc "rt" f ft + mustHaveFields loc ft fs ve_t + ctEq t rt checkCase :: StructType -> @@ -735,12 +740,10 @@ checkExp (QualParens (modname, modnameloc) e loc) = do "Module" <+> pretty modname <+> " is a parametric module." -- checkExp (IntLit x NoInfo loc) = do - t <- newType loc "num" - ctOneOf t anyNumberType + t <- newTypeOverloaded loc "num" anyNumberType pure $ IntLit x (Info t) loc checkExp (FloatLit x NoInfo loc) = do - t <- newType loc "float" - ctOneOf t anyFloatType + t <- newTypeOverloaded loc "float" anyFloatType pure $ FloatLit x (Info t) loc checkExp (Literal v loc) = pure $ Literal v loc @@ -792,9 +795,8 @@ checkExp (Assert e1 e2 NoInfo loc) = do pure $ Assert e1' e2' (Info (prettyText e1)) loc -- checkExp (Constr name es NoInfo loc) = do - t <- newType loc "t" es' <- mapM checkExp es - ctHasConstr t name $ map typeOf es' + t <- newTypeWithConstr loc "t" name $ map expType es' pure $ Constr name es' (Info t) loc -- checkExp (AppExp (Apply fe args loc) NoInfo) = do @@ -866,7 +868,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do checkExp (ProjectSection fields NoInfo loc) = do a <- newType loc "a" b <- newType loc "b" - mustHaveFields loc a fields b + mustHaveFields loc (toType a) fields (toType b) let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b pure $ ProjectSection fields (Info ft) loc -- @@ -928,14 +930,15 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e - kt <- newType loc "t" - ctHasField (typeOf e') k kt + kt <- newType loc "kt" + t <- newTypeWithField loc "t" k kt + ctEq (typeOf e') t pure $ Project k e' (Info kt) loc -- checkExp (RecordUpdate src fields ve NoInfo loc) = do src' <- checkExp src ve' <- checkExp ve - mustHaveFields loc (typeOf src') fields (typeOf ve') + mustHaveFields loc (expType src') fields (expType ve') pure $ RecordUpdate src' fields ve' (Info (typeOf src')) loc -- checkExp (IndexSection slice NoInfo loc) = do From d36f15ede541972667ebcb56a5ad65b159fd4082 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 6 Feb 2024 11:37:38 +0100 Subject: [PATCH 003/296] Move overloading from constraints to tyvars. --- src/Language/Futhark/TypeChecker/Terms2.hs | 70 ++++++++++------------ 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 52106ad28d..036338bc1d 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -59,7 +59,7 @@ import Language.Futhark import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types -import Language.Futhark.TypeChecker.Unify (Level, mkUsage) +import Language.Futhark.TypeChecker.Unify (Level) import Prelude hiding (mod) data Inferred t @@ -84,40 +84,36 @@ toType = bimap (const ()) (const NoUniqueness) expType :: Exp -> Type expType = toType . typeOf -data Ct - = CtEq Type Type - | CtOneOf Type [PrimType] - | CtHasConstr Type Name [Type] - | CtHasField Type Name Type +data Ct = CtEq Type Type deriving (Show) instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtOneOf t1 ts) = pretty t1 <+> "∈" <+> pretty ts - pretty (CtHasConstr t1 k ts) = - pretty t1 <+> "~" <+> "... | " <+> hsep ("#" <> pretty k : map pretty ts) <+> " | ..." - pretty (CtHasField t1 k t) = - pretty t1 <+> "~" <+> braces ("..." <+> pretty k <> ":" <+> pretty t <+> "...") type Constraints = [Ct] --- | The substitution (or other information) known about a type --- variable. -data TyVarSub - = -- | No substitution known yet; can be substituted with anything. +-- | Information about a type variable. +data TyVarInfo + = -- | Can be substituted with anything. TyVarFree - | -- | This substitution has been found. - TyVarSub Type + | -- | Can only be substituted with these primitive types. + TyVarPrim [PrimType] + | -- | Must be a record with these fields. + TyVarRecord (M.Map Name Type) + | -- | Must be a sum type with these fields. + TyVarSum (M.Map Name [Type]) deriving (Show) -instance Pretty TyVarSub where +instance Pretty TyVarInfo where pretty TyVarFree = "free" - pretty (TyVarSub t) = "=" <> pretty t + pretty (TyVarPrim pts) = "∈" <+> pretty pts + pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs + pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs type TyVar = VName -- | If a VName is not in this map, it is assumed to be rigid. -type TyVars = M.Map TyVar TyVarSub +type TyVars = M.Map TyVar TyVarInfo data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding, @@ -249,33 +245,30 @@ incCounter = do tyVarType :: (Monoid u) => TyVar -> TypeBase dim u tyVarType v = Scalar $ TypeVar mempty (qualName v) [] -newTyVar :: a -> Name -> TermM TyVar -newTyVar loc desc = do +newTyVarWith :: a -> Name -> TyVarInfo -> TermM TyVar +newTyVarWith loc desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i - modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} + modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} pure v +newTyVar :: a -> Name -> TermM TyVar +newTyVar loc desc = newTyVarWith loc desc TyVarFree + newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) newType loc desc = tyVarType <$> newTyVar loc desc newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> TypeBase dim u -> TermM (TypeBase dim u) -newTypeWithField loc desc k t = do - rt <- newType loc desc - addCt $ CtHasField (toType rt) k (toType t) - pure rt +newTypeWithField loc desc k t = + tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k $ toType t) newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [Type] -> TermM (TypeBase dim u) -newTypeWithConstr loc desc k ts = do - t <- newType loc desc - addCt $ CtHasConstr (toType t) k ts - pure t +newTypeWithConstr loc desc k ts = + tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts) newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase dim u) -newTypeOverloaded loc name pts = do - t <- newType loc name - addCt $ CtOneOf (toType t) pts - pure t +newTypeOverloaded loc name pts = + tyVarType <$> newTyVarWith loc name (TyVarPrim pts) addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} @@ -1041,9 +1034,6 @@ checkExp (Coerce e te NoInfo loc) = do ctEq (typeOf e') st pure $ Coerce e' te' (Info (toStruct st)) loc --- --- - checkValDef :: ( VName, Maybe (TypeExp NoInfo VName), @@ -1068,8 +1058,8 @@ checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do unlines [ "function " <> prettyNameString fname, "constraints:", - prettyString cts, + unlines $ map prettyString cts, "tyvars:", - prettyString $ map (first prettyNameString) $ M.toList tyvars + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] pure (undefined, params', undefined, undefined, body') From c032bcff60489780acdf258d1c2a3049cc95d2f4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 6 Feb 2024 13:10:47 +0100 Subject: [PATCH 004/296] Make some room for a solver. --- futhark.cabal | 1 + .../Futhark/TypeChecker/Constraints.hs | 78 +++++++++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 83 ++++++++----------- 3 files changed, 114 insertions(+), 48 deletions(-) create mode 100644 src/Language/Futhark/TypeChecker/Constraints.hs diff --git a/futhark.cabal b/futhark.cabal index d9918a6554..c89eacb214 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -409,6 +409,7 @@ library Language.Futhark.Tuple Language.Futhark.TypeChecker Language.Futhark.TypeChecker.Consumption + Language.Futhark.TypeChecker.Constraints Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs new file mode 100644 index 0000000000..e33623704d --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -0,0 +1,78 @@ +module Language.Futhark.TypeChecker.Constraints + ( Type, + Ct (..), + Constraints, + TyVarInfo (..), + TyVar, + TyVars, + solve, + ) +where + +import Control.Monad.Except +import Control.Monad.State +import Data.Bifunctor +import Data.Map qualified as M +import Data.Text qualified as T +import Futhark.Util.Pretty +import Language.Futhark + +type Type = TypeBase () NoUniqueness + +data Ct = CtEq Type Type + deriving (Show) + +instance Pretty Ct where + pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 + +type Constraints = [Ct] + +-- | Information about a type variable. +data TyVarInfo + = -- | Can be substituted with anything. + TyVarFree + | -- | Can only be substituted with these primitive types. + TyVarPrim [PrimType] + | -- | Must be a record with these fields. + TyVarRecord (M.Map Name Type) + | -- | Must be a sum type with these fields. + TyVarSum (M.Map Name [Type]) + deriving (Show) + +instance Pretty TyVarInfo where + pretty TyVarFree = "free" + pretty (TyVarPrim pts) = "∈" <+> pretty pts + pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs + pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs + +type TyVar = VName + +-- | If a VName is not in this map, it is assumed to be rigid. +type TyVars = M.Map TyVar TyVarInfo + +data TyVarSol + = -- | Has been substituted with this. + TyVarSol Type + | -- | Not substituted yet; has this constraint. + TyVarUnsol TyVarInfo + deriving (Show) + +newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} + +initialState :: TyVars -> SolverState +initialState tyvars = SolverState $ M.map TyVarUnsol tyvars + +solution :: SolverState -> M.Map TyVar Type +solution = undefined + +newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} + deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) + +solve :: Constraints -> TyVars -> Either T.Text (M.Map TyVar Type) +solve constraints tyvars = + second solution + . runExcept + . flip execStateT (initialState tyvars) + . runSolveM + $ throwError "cannot solve" +{-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 036338bc1d..eddc209b1a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -56,6 +56,7 @@ import Futhark.MonadFreshNames hiding (newName) import Futhark.Util (mapAccumLM) import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types @@ -76,45 +77,12 @@ data ValBinding | EqualityF deriving (Show) -type Type = TypeBase () NoUniqueness - toType :: TypeBase d u -> Type toType = bimap (const ()) (const NoUniqueness) expType :: Exp -> Type expType = toType . typeOf -data Ct = CtEq Type Type - deriving (Show) - -instance Pretty Ct where - pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - -type Constraints = [Ct] - --- | Information about a type variable. -data TyVarInfo - = -- | Can be substituted with anything. - TyVarFree - | -- | Can only be substituted with these primitive types. - TyVarPrim [PrimType] - | -- | Must be a record with these fields. - TyVarRecord (M.Map Name Type) - | -- | Must be a sum type with these fields. - TyVarSum (M.Map Name [Type]) - deriving (Show) - -instance Pretty TyVarInfo where - pretty TyVarFree = "free" - pretty (TyVarPrim pts) = "∈" <+> pretty pts - pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs - pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs - -type TyVar = VName - --- | If a VName is not in this map, it is assumed to be rigid. -type TyVars = M.Map TyVar TyVarInfo - data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding, scopeTypeTable :: M.Map VName TypeBinding, @@ -270,6 +238,18 @@ newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBa newTypeOverloaded loc name pts = tyVarType <$> newTyVarWith loc name (TyVarPrim pts) +asStructType :: (Monoid u) => SrcLoc -> TypeBase d u -> TermM (TypeBase Size u) +asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt +asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] +asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do + t1' <- asStructType loc t1 + t2' <- asStructType loc t2 + pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' +asStructType loc t = do + t' <- newType loc "artificial" + ctEq t' t + pure t' + addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} @@ -613,13 +593,16 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM TyVar +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM Type +checkApply loc (fname, _) (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do + ctEq a $ expType arg + pure $ toType b checkApply loc (fname, _) ftype arg = do - a <- newType loc "a" - b <- newTyVar loc "b" + a <- newType loc "arg" + b <- newTyVar loc "res" ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) ctEq a (expType arg) - pure b + pure $ tyVarType b checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex @@ -794,9 +777,9 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - ((_, rt), args') <- mapAccumLM onArg (0, typeOf fe') args - - pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt [] + ((_, rt), args') <- mapAccumLM onArg (0, expType fe') args + rt' <- asStructType loc rt + pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] where fname = case fe of @@ -807,7 +790,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do arg' <- checkExp arg rt <- checkApply loc (fname, i) (toType f_t) arg' pure - ( (i + 1, tyVarType rt), + ( (i + 1, rt), (Info Nothing, arg') ) -- @@ -817,17 +800,19 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e2' <- checkExp e2 rt1 <- checkApply loc (Just op, 0) (toType ftype) e1' - rt2 <- checkApply loc (Just op, 1) (tyVarType rt1) e2' + rt2 <- checkApply loc (Just op, 1) rt1 e2' + rt2' <- asStructType loc rt2 pure $ AppExp (BinOp (op, oploc) (Info ftype) (e1', Info Nothing) (e2', Info Nothing) loc) - (Info (AppRes (tyVarType rt2) [])) + (Info (AppRes rt2' [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e rt <- checkApply loc (Just op, 0) (toType optype) e' + rt' <- asStructType loc rt pure $ OpSectionLeft op @@ -837,7 +822,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do ( Info (Unnamed, Scalar $ Prim Bool, Nothing), Info (Unnamed, Scalar $ Prim Bool) ) - (Info (RetType [] (tyVarType rt)), Info []) + (Info (RetType [] $ toRes Nonunique rt'), Info []) loc -- checkExp (OpSectionRight op _ e _ NoInfo loc) = do @@ -1056,10 +1041,12 @@ checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars traceM $ unlines - [ "function " <> prettyNameString fname, - "constraints:", + [ "# function " <> prettyNameString fname, + "## constraints:", unlines $ map prettyString cts, - "tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + "## solution:", + either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) $ solve cts tyvars ] pure (undefined, params', undefined, undefined, body') From 2b08f83959bd47f32f8452104a76cfa6dde137b6 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 7 Feb 2024 13:52:03 +0100 Subject: [PATCH 005/296] Comment for Robert. --- src/Language/Futhark/TypeChecker/Constraints.hs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index e33623704d..dfcb2d3241 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -17,7 +17,15 @@ import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark -type Type = TypeBase () NoUniqueness +-- | A shape component is currently just unit. The rank of an array is +-- then just the number of shape components it contains in its shape +-- list. When we add AUTOMAP, these components will also allow shape +-- variables. The list of components should then be understood as +-- concatenation of shapes (meaning you can't just take the length to +-- determine the rank of the array). +type SComp = () + +type Type = TypeBase SComp NoUniqueness data Ct = CtEq Type Type deriving (Show) From 0399a6d5ee9b7cbfc8e40481bb3473de43dbebe8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 7 Feb 2024 14:46:00 +0100 Subject: [PATCH 006/296] WIP in solving. --- .../Futhark/TypeChecker/Constraints.hs | 69 ++++++++++++++++++- src/Language/Futhark/TypeChecker/Types.hs | 2 +- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index dfcb2d3241..f50e570e8f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -61,6 +61,8 @@ type TyVars = M.Map TyVar TyVarInfo data TyVarSol = -- | Has been substituted with this. TyVarSol Type + | -- | Replaced by this other type variable. + TyVarLink VName | -- | Not substituted yet; has this constraint. TyVarUnsol TyVarInfo deriving (Show) @@ -70,17 +72,80 @@ newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} initialState :: TyVars -> SolverState initialState tyvars = SolverState $ M.map TyVarUnsol tyvars +substTyVars :: (Monoid u) => M.Map TyVar TyVarSol -> TypeBase SComp u -> TypeBase SComp u +substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = + case M.lookup v m of + Just (TyVarLink v') -> + substTyVars m $ Scalar $ TypeVar u (QualName qs v') args + Just (TyVarSol t') -> second (const mempty) t' + Just (TyVarUnsol _) -> t + Nothing -> t +substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt +substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs +substTyVars m (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars m) cs +substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = + Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 +substTyVars m (Array u shape elemt) = + arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt + solution :: SolverState -> M.Map TyVar Type -solution = undefined +solution s = M.mapMaybe f $ solverTyVars s + where + f (TyVarSol t) = Just $ substTyVars (solverTyVars s) t + f (TyVarLink v) = f =<< M.lookup v (solverTyVars s) + f (TyVarUnsol _) = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) +subTyVar :: VName -> Type -> SolveM () +subTyVar v t = + modify $ \s -> s {solverTyVars = M.insert v (TyVarSol t) $ solverTyVars s} + +linkTyVar :: VName -> VName -> SolveM () +linkTyVar v t = + modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} + +unify :: Type -> Type -> Maybe [(Type, Type)] +unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) + | pt1 == pt2 = Just [] +unify _ _ = Nothing + +solveCt :: Ct -> SolveM () +solveCt ct = do + let CtEq t1 t2 = ct + solveCt' (t1, t2) + where + bad = throwError $ "Unsolvable: " <> prettyText ct + solveCt' (t1, t2) = do + tyvars <- gets solverTyVars + let flexible v = case M.lookup v tyvars of + Just (TyVarLink v') -> flexible v' + Just (TyVarUnsol _) -> True + Just (TyVarSol _) -> False + Nothing -> False + case (t1, t2) of + ( Scalar (TypeVar _ (QualName [] v1) []), + Scalar (TypeVar _ (QualName [] v2) []) + ) -> + case (flexible v1, flexible v2) of + (False, False) -> bad + (True, False) -> subTyVar v1 t2 + (False, True) -> subTyVar v2 t1 + (True, True) -> linkTyVar v1 v2 + (Scalar (TypeVar _ (QualName [] v1) []), _) -> + if flexible v1 then subTyVar v1 t2 else bad + (_, Scalar (TypeVar _ (QualName [] v2) [])) -> + if flexible v2 then subTyVar v2 t1 else bad + _ -> case unify t1 t2 of + Nothing -> bad + Just eqs -> mapM_ solveCt' eqs + solve :: Constraints -> TyVars -> Either T.Text (M.Map TyVar Type) solve constraints tyvars = second solution . runExcept . flip execStateT (initialState tyvars) . runSolveM - $ throwError "cannot solve" + $ mapM solveCt constraints {-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 5c8673114b..5a364464b0 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -472,7 +472,7 @@ substTypesRet lookupSubst ot = onType (Array u shape et) = arrayOfWithAliases u (applySubst lookupSubst' shape) - <$> onType (second (const mempty) $ Scalar et) + <$> onType (Scalar et) onType (Scalar (Prim t)) = pure $ Scalar $ Prim t onType (Scalar (TypeVar u v targs)) = do targs' <- mapM subsTypeArg targs From c03d00d73d673a92f9db06c8f219ee8e08323d06 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 7 Feb 2024 16:55:56 +0100 Subject: [PATCH 007/296] Basically functional solver. --- src/Language/Futhark/Pretty.hs | 7 +- .../Futhark/TypeChecker/Constraints.hs | 59 +++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 72 ++++++++++--------- 3 files changed, 89 insertions(+), 49 deletions(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 4c91f277f7..c5134a4acd 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -7,6 +7,7 @@ module Language.Futhark.Pretty prettyTuple, leadingOperator, IsName (..), + prettyNameText, prettyNameString, Annot (..), ) @@ -55,9 +56,13 @@ instance IsName Name where prettyName = pretty toName = id +-- | Prettyprint name as text. +prettyNameText :: (IsName v) => v -> T.Text +prettyNameText = docText . prettyName + -- | Prettyprint name as string. Only use this for debugging. prettyNameString :: (IsName v) => v -> String -prettyNameString = T.unpack . docText . prettyName +prettyNameString = T.unpack . prettyNameText -- | Class for type constructors that represent annotations. Used in -- the prettyprinter to either print the original AST, or the computed diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index f50e570e8f..74a2dd2fd0 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,5 +1,6 @@ module Language.Futhark.TypeChecker.Constraints ( Type, + toType, Ct (..), Constraints, TyVarInfo (..), @@ -25,8 +26,13 @@ import Language.Futhark -- determine the rank of the array). type SComp = () +-- | The type representation used by the constraint solver. Agnostic +-- to sizes. type Type = TypeBase SComp NoUniqueness +toType :: TypeBase d u -> Type +toType = bimap (const ()) (const NoUniqueness) + data Ct = CtEq Type Type deriving (Show) @@ -106,9 +112,26 @@ linkTyVar :: VName -> VName -> SolveM () linkTyVar v t = modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} +-- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Just [] +unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = + Just [(t1a, t2a), (toType t1r, toType t2r)] +unify (Scalar (Record fs1)) (Scalar (Record fs2)) + | M.keys fs1 == M.keys fs2 = + Just $ M.elems $ M.intersectionWith (,) fs1 fs2 +unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) + | M.keys cs1 == M.keys cs2 = do + fmap concat + . forM (M.elems $ M.intersectionWith (,) cs1 cs2) + $ \(ts1, ts2) -> do + guard $ length ts1 == length ts2 + Just $ zip ts1 ts2 +unify t1 t2 + | Just t1' <- peelArray 1 t1, + Just t2' <- peelArray 1 t2 = + Just [(t1', t2')] unify _ _ = Nothing solveCt :: Ct -> SolveM () @@ -124,20 +147,28 @@ solveCt ct = do Just (TyVarUnsol _) -> True Just (TyVarSol _) -> False Nothing -> False - case (t1, t2) of - ( Scalar (TypeVar _ (QualName [] v1) []), - Scalar (TypeVar _ (QualName [] v2) []) - ) -> - case (flexible v1, flexible v2) of - (False, False) -> bad - (True, False) -> subTyVar v1 t2 - (False, True) -> subTyVar v2 t1 - (True, True) -> linkTyVar v1 v2 - (Scalar (TypeVar _ (QualName [] v1) []), _) -> - if flexible v1 then subTyVar v1 t2 else bad - (_, Scalar (TypeVar _ (QualName [] v2) [])) -> - if flexible v2 then subTyVar v2 t1 else bad - _ -> case unify t1 t2 of + sub t@(Scalar (TypeVar u (QualName [] v) [])) = + case M.lookup v tyvars of + Just (TyVarLink v') -> sub $ Scalar (TypeVar u (QualName [] v') []) + Just (TyVarSol t') -> sub t' + _ -> t + sub t = t + case (sub t1, sub t2) of + ( t1'@(Scalar (TypeVar _ (QualName [] v1) [])), + t2'@(Scalar (TypeVar _ (QualName [] v2) [])) + ) + | v1 == v2 -> pure () + | otherwise -> + case (flexible v1, flexible v2) of + (False, False) -> bad + (True, False) -> subTyVar v1 t2' + (False, True) -> subTyVar v2 t1' + (True, True) -> linkTyVar v1 v2 + (Scalar (TypeVar _ (QualName [] v1) []), t2') -> + if flexible v1 then subTyVar v1 t2' else bad + (t1', Scalar (TypeVar _ (QualName [] v2) [])) -> + if flexible v2 then subTyVar v2 t1' else bad + (t1', t2') -> case unify t1' t2' of Nothing -> bad Just eqs -> mapM_ solveCt' eqs diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index eddc209b1a..b1b83f6f2b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -77,9 +77,6 @@ data ValBinding | EqualityF deriving (Show) -toType :: TypeBase d u -> Type -toType = bimap (const ()) (const NoUniqueness) - expType :: Exp -> Type expType = toType . typeOf @@ -254,7 +251,13 @@ addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () -ctEq t1 t2 = addCt $ CtEq (toType t1) (toType t2) +ctEq t1 t2 = + -- As a minor optimisation, do not add constraint if the types are + -- equal. + unless (t1' == t2') $ addCt $ CtEq t1' t2' + where + t1' = toType t1 + t2' = toType t2 localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -287,7 +290,7 @@ instance MonadError TypeError TermM where f' (_, e) = let TermM m' = f e in m' instance MonadTypeChecker TermM where - checkExpForSize = checkExp + checkExpForSize = require "use as size" [Signed Int64] <=< checkExp warnings ws = modify $ \s -> s {termWarnings = termWarnings s <> ws} @@ -325,6 +328,10 @@ instance MonadTypeChecker TermM where --- All the general machinery goes above. +arrayOfRank :: Int -> Type -> Type +arrayOfRank 0 t = t +arrayOfRank n t = arrayOf (Shape $ replicate n ()) t + require :: T.Text -> [PrimType] -> Exp -> TermM Exp require why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts @@ -365,8 +372,8 @@ instTypeScheme qn loc tparams t = do TypeParamType x _ _ -> do i <- incCounter let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newID $ mkTypeVarName name i - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) [])) + v <- newTyVar loc name + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v)) TypeParamDim {} -> pure Nothing let t' = applySubst (`lookup` substs) t @@ -394,13 +401,7 @@ lookupVar loc qn@(QualName qs name) = do pure $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do argtype <- newType loc "t" - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded (argtype :: StructType) pts rt @@ -811,34 +812,37 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - rt <- checkApply loc (Just op, 0) (toType optype) e' - rt' <- asStructType loc rt + void $ checkApply loc (Just op, 0) (toType optype) e' + let t1 = typeOf e' + t2 <- newType loc "t2" + rt <- newType loc "rt" + ctEq optype $ foldFunType [toParam Observe t1, t2] $ RetType [] rt pure $ OpSectionLeft op (Info optype) e' - -- Dummy types. - ( Info (Unnamed, Scalar $ Prim Bool, Nothing), - Info (Unnamed, Scalar $ Prim Bool) + ( Info (Unnamed, toParam Observe t1, Nothing), + Info (Unnamed, t2) ) - (Info (RetType [] $ toRes Nonunique rt'), Info []) + (Info (RetType [] rt), Info []) loc -- checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e t1 <- newType loc "t" + let t2 = typeOf e' rt <- newType loc "rt" - ctEq optype $ foldFunType [t1, toParam Observe $ typeOf e'] $ RetType [] rt + ctEq optype $ foldFunType [t1, toParam Observe t2] $ RetType [] rt pure $ OpSectionRight op (Info optype) e' -- Dummy types. - ( Info (Unnamed, Scalar $ Prim Bool), - Info (Unnamed, Scalar $ Prim Bool, Nothing) + ( Info (Unnamed, toParam Observe t1), + Info (Unnamed, toParam Observe t2, Nothing) ) (Info $ RetType [] rt) loc @@ -901,7 +905,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do maybe_step' <- traverse checkExp' maybe_step end' <- traverse checkExp' end range_t <- newType loc "range" - ctEq range_t $ arrayOf (Shape [()]) (toType (typeOf start')) + ctEq range_t $ arrayOfRank 1 (toType (typeOf start')) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] where checkExp' = require "use in range expression" anyIntType <=< checkExp @@ -925,8 +929,8 @@ checkExp (IndexSection slice NoInfo loc) = do index_elem_t <- newType loc "index_elem" index_res_t <- newType loc "index_res" let num_slices = length $ filter isSlice slice - ctEq index_arg_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t - ctEq index_res_t $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + ctEq index_arg_t $ arrayOfRank num_slices index_elem_t + ctEq index_res_t $ arrayOfRank (length slice) index_elem_t let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ RetType [] index_res_t pure $ IndexSection slice' (Info ft) loc -- @@ -936,8 +940,8 @@ checkExp (AppExp (Index e slice loc) _) = do index_t <- newType loc "index" index_elem_t <- newType loc "index_elem" let num_slices = length $ filter isSlice slice - ctEq index_t $ arrayOf (Shape (replicate num_slices ())) index_elem_t - ctEq (typeOf e') $ arrayOf (Shape (replicate (length slice) ())) index_elem_t + ctEq index_t $ arrayOfRank num_slices index_elem_t + ctEq (typeOf e') $ arrayOfRank (length slice) index_elem_t pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) -- checkExp (Update src slice ve loc) = do @@ -946,8 +950,8 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq (typeOf src') $ arrayOf (Shape (replicate (length slice) ())) update_elem_t - ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + ctEq (typeOf src') $ arrayOfRank (length slice) update_elem_t + ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do @@ -958,8 +962,8 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq src_t $ arrayOf (Shape (replicate (length slice) ())) update_elem_t - ctEq (typeOf ve') $ arrayOf (Shape (replicate num_slices ())) update_elem_t + ctEq src_t $ arrayOfRank (length slice) update_elem_t + ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) @@ -972,7 +976,7 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do ctEq (typeOf e1') (Scalar (Prim Bool) :: Type) ctEq (typeOf e2') (typeOf e3') - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e1') []) + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) -- checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e @@ -999,7 +1003,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" - ctEq (typeOf arr') $ arrayOf (Shape [()]) (toType elem_t) + ctEq (typeOf arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') From 0cd1562bf8192500669df9f36b70a3f76acbcf4d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 15:49:09 +0100 Subject: [PATCH 008/296] Check return ascriptions. --- src/Language/Futhark/TypeChecker/Terms2.hs | 39 +++++++++++++--------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b1b83f6f2b..49e2b9ae99 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -370,7 +370,6 @@ instTypeScheme qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do case tparam of TypeParamType x _ _ -> do - i <- incCounter let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) v <- newTyVar loc name pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v)) @@ -686,6 +685,13 @@ instance Pretty (Unmatched (Pat StructType)) where pretty' (PatLit e _ _) = pretty e pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) +checkRetDecl :: Exp -> Maybe (TypeExp NoInfo VName) -> TermM (Maybe (TypeExp Info VName)) +checkRetDecl _ Nothing = pure Nothing +checkRetDecl body (Just te) = do + (te', _, RetType _ st, _) <- checkTypeExp te + ctEq (typeOf body) st + pure $ Just te' + checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- checkExp (Var qn _ loc) = do @@ -854,17 +860,12 @@ checkExp (ProjectSection fields NoInfo loc) = do let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b pure $ ProjectSection fields (Info ft) loc -- -checkExp (Lambda params body rettype NoInfo loc) = do +checkExp (Lambda params body retdecl NoInfo loc) = do bindParams [] params $ \params' -> do body' <- checkExp body - rettype_te' <- case rettype of - Just rettype_te -> do - (rettype_te', _, RetType _ st, _) <- checkTypeExp rettype_te - ctEq (typeOf body') st - pure $ Just rettype_te' - Nothing -> pure Nothing + retdecl' <- checkRetDecl body' retdecl let ret = RetType [] $ toRes Nonunique $ typeOf body' - pure $ Lambda params' body' rettype_te' (Info ret) loc + pure $ Lambda params' body' retdecl' (Info ret) loc -- checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e @@ -876,12 +877,13 @@ checkExp (AppExp (LetPat sizes pat e body loc) _) = do (LetPat sizes (fmap toStruct pat') e' body' loc) (Info $ AppRes (typeOf body') []) -- -checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do - (tparams', params', maybe_retdecl', rettype, e') <- +checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) = do + (tparams', params', retdecl', rettype, e') <- bindParams tparams params $ \params' -> do e' <- checkExp e let ret = RetType [] $ toRes Nonunique $ typeOf e' - pure (tparams, params', undefined, ret, e') + retdecl' <- checkRetDecl e' retdecl + pure (tparams, params', retdecl', ret, e') let entry = BoundV tparams' $ funType params' rettype bindF scope = @@ -894,7 +896,7 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l AppExp ( LetFun name - (tparams', params', maybe_retdecl', Info rettype, e') + (tparams', params', retdecl', Info rettype, e') body' loc ) @@ -1035,14 +1037,19 @@ checkValDef :: ( [TypeParam], [Pat ParamType], Maybe (TypeExp Info VName), - ResRetType, Exp ) -checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do +checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bindParams tparams params $ \params' -> do body' <- checkExp body + + retdecl' <- checkRetDecl body' retdecl + cts <- gets termConstraints tyvars <- gets termTyVars + + let solution = solve cts tyvars + traceM $ unlines [ "# function " <> prettyNameString fname, @@ -1053,4 +1060,4 @@ checkValDef (fname, maybe_retdecl, tparams, params, body, loc) = runTermM $ do "## solution:", either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) $ solve cts tyvars ] - pure (undefined, params', undefined, undefined, body') + pure (undefined, params', retdecl', body') From 7de2678e7c7bf8d3fdd1ee8ec76724db757ac828 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 15:53:28 +0100 Subject: [PATCH 009/296] Clean up things. --- src/Language/Futhark/TypeChecker/Terms2.hs | 25 +++------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 49e2b9ae99..ee1b04da01 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -217,7 +217,7 @@ newTyVarWith loc desc info = do modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} pure v -newTyVar :: a -> Name -> TermM TyVar +newTyVar :: (Located loc) => loc -> Name -> TermM TyVar newTyVar loc desc = newTyVarWith loc desc TyVarFree newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) @@ -338,25 +338,6 @@ require why pts e = do ctEq t $ expType e pure e --- | Create a new type name and insert it (unconstrained) in the set --- of type variables. -instTypeParam :: - (Monoid as) => - QualName VName -> - SrcLoc -> - TypeParam -> - TermM (VName, Subst (RetTypeBase dim as)) -instTypeParam qn loc tparam = do - i <- incCounter - let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newID $ mkTypeVarName name i - case tparam of - TypeParamType {} -> do - modify $ \s -> s {termTyVars = M.insert v TyVarFree $ termTyVars s} - pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) - TypeParamDim {} -> - pure (v, ExpSubst $ sizeFromName (qualName v) loc) - -- | Instantiate a type scheme with fresh type variables for its type -- parameters. Returns the names of the fresh type variables, the -- instance list, and the instantiated type. @@ -481,7 +462,7 @@ checkPat' (RecordPat fs loc) NoneInferred = <$> traverse (`checkPat'` NoneInferred) (M.fromList fs) <*> pure loc checkPat' (PatAscription p t loc) maybe_outer_t = do - (t', _, RetType dims st, _) <- checkTypeExp t + (t', _, RetType _ st, _) <- checkTypeExp t case maybe_outer_t of Ascribed outer_t -> do @@ -1058,6 +1039,6 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, "## solution:", - either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) $ solve cts tyvars + either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) solution ] pure (undefined, params', retdecl', body') From 51eafc66331a76de9eac11831ccc367a9695524b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 15:56:07 +0100 Subject: [PATCH 010/296] Remove most warnings. --- src/Language/Futhark/TypeChecker/Terms2.hs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index ee1b04da01..ee62ed478f 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -210,8 +210,8 @@ incCounter = do tyVarType :: (Monoid u) => TyVar -> TypeBase dim u tyVarType v = Scalar $ TypeVar mempty (qualName v) [] -newTyVarWith :: a -> Name -> TyVarInfo -> TermM TyVar -newTyVarWith loc desc info = do +newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar +newTyVarWith _loc desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} @@ -333,7 +333,7 @@ arrayOfRank 0 t = t arrayOfRank n t = arrayOf (Shape $ replicate n ()) t require :: T.Text -> [PrimType] -> Exp -> TermM Exp -require why pts e = do +require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts ctEq t $ expType e pure e @@ -347,13 +347,12 @@ instTypeScheme :: [TypeParam] -> StructType -> TermM ([VName], StructType) -instTypeScheme qn loc tparams t = do +instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do case tparam of - TypeParamType x _ _ -> do - let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newTyVar loc name - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v)) + TypeParamType _ v _ -> do + v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) TypeParamDim {} -> pure Nothing let t' = applySubst (`lookup` substs) t @@ -575,10 +574,10 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM Type -checkApply loc (fname, _) (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do +checkApply _ _ (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do ctEq a $ expType arg pure $ toType b -checkApply loc (fname, _) ftype arg = do +checkApply loc _ ftype arg = do a <- newType loc "arg" b <- newTyVar loc "res" ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) @@ -1020,7 +1019,7 @@ checkValDef :: Maybe (TypeExp Info VName), Exp ) -checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do +checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do bindParams tparams params $ \params' -> do body' <- checkExp body From 7b78c864a0cae2a4798f37d9fdbf6181584e7e0b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 18:27:26 +0100 Subject: [PATCH 011/296] Defective type checker work. --- src/Language/Futhark/TypeChecker.hs | 3 - src/Language/Futhark/TypeChecker/Terms.hs | 266 ++++++++---------- .../Futhark/TypeChecker/Terms/Loop.hs | 26 +- .../Futhark/TypeChecker/Terms/Monad.hs | 24 +- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 85 +----- src/Language/Futhark/TypeChecker/Terms2.hs | 13 +- 6 files changed, 167 insertions(+), 250 deletions(-) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 23e02b79d7..9bc29be7d4 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -34,7 +34,6 @@ import Language.Futhark.TypeChecker.Modules import Language.Futhark.TypeChecker.Monad import Language.Futhark.TypeChecker.Names import Language.Futhark.TypeChecker.Terms -import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 import Language.Futhark.TypeChecker.Types import Prelude hiding (abs, mod) @@ -696,8 +695,6 @@ checkValBind vb = do attrs' <- mapM checkAttr attrs - void $ Terms2.checkValDef (fname, maybe_tdecl, tparams, params, body, loc) - (tparams', params', maybe_tdecl', rettype, body') <- checkFunDef (fname, maybe_tdecl, tparams, params, body, loc) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d933287ab3..3e6ecbff19 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -38,6 +38,7 @@ import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Terms.Loop import Language.Futhark.TypeChecker.Terms.Monad import Language.Futhark.TypeChecker.Terms.Pat +import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) @@ -182,8 +183,8 @@ sliceShape _ _ t = pure (t, []) checkAscript :: SrcLoc -> - TypeExp (ExpBase NoInfo VName) VName -> - ExpBase NoInfo VName -> + TypeExp Exp VName -> + Exp -> TermTypeM (TypeExp Exp VName, Exp) checkAscript loc te e = do (te', decl_t, _) <- checkTypeExpNonrigid te @@ -197,8 +198,8 @@ checkAscript loc te e = do checkCoerce :: SrcLoc -> - TypeExp (ExpBase NoInfo VName) VName -> - ExpBase NoInfo VName -> + TypeExp Exp VName -> + Exp -> TermTypeM (TypeExp Exp VName, StructType, Exp) checkCoerce loc te e = do (te', te_t, ext) <- checkTypeExpNonrigid te @@ -347,48 +348,26 @@ unscopeType :: unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp -checkExp :: ExpBase NoInfo VName -> TermTypeM Exp +checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc -checkExp (Hole _ loc) = do - t <- newTypeVar loc "t" - pure $ Hole (Info t) loc +checkExp (Hole info loc) = + pure $ Hole info loc checkExp (StringLit vs loc) = pure $ StringLit vs loc -checkExp (IntLit val NoInfo loc) = do - t <- newTypeVar loc "t" - mustBeOneOf anyNumberType (mkUsage loc "integer literal") t - pure $ IntLit val (Info t) loc -checkExp (FloatLit val NoInfo loc) = do - t <- newTypeVar loc "t" - mustBeOneOf anyFloatType (mkUsage loc "float literal") t - pure $ FloatLit val (Info t) loc +checkExp (IntLit val info loc) = + pure $ IntLit val info loc +checkExp (FloatLit val info loc) = + pure $ FloatLit val info loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc checkExp (RecordLit fs loc) = - RecordLit <$> evalStateT (mapM checkField fs) mempty <*> pure loc + RecordLit <$> mapM checkField fs <*> pure loc where - checkField (RecordFieldExplicit f e rloc) = do - errIfAlreadySet f rloc - modify $ M.insert f rloc - RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc - checkField (RecordFieldImplicit name NoInfo rloc) = do - errIfAlreadySet (baseName name) rloc - t <- lift $ lookupVar rloc $ qualName name - modify $ M.insert (baseName name) rloc - pure $ RecordFieldImplicit name (Info t) rloc - - errIfAlreadySet f rloc = do - maybe_sloc <- gets $ M.lookup f - case maybe_sloc of - Just sloc -> - lift . typeError rloc mempty $ - "Field" - <+> dquotes (pretty f) - <+> "previously defined at" - <+> pretty (locStrRel rloc sloc) - <> "." - Nothing -> pure () + checkField (RecordFieldExplicit f e rloc) = + RecordFieldExplicit f <$> checkExp e <*> pure rloc + checkField (RecordFieldImplicit name info rloc) = + pure $ RecordFieldImplicit name info rloc checkExp (ArrayLit all_es _ loc) = -- Construct the result type and unify all elements with it. We -- only create a type variable for empty arrays; otherwise we use @@ -484,12 +463,12 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do checkExp (Ascript e te loc) = do (te', e') <- checkAscript loc te e pure $ Ascript e' te' loc -checkExp (Coerce e te NoInfo loc) = do +checkExp (Coerce e te _ loc) = do (te', te_t, e') <- checkCoerce loc te e t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do +checkExp (AppExp (BinOp (op, oploc) _ (e1, _) (e2, _) loc) _) = do ftype <- lookupVar oploc op e1' <- checkExp e1 e2' <- checkExp e2 @@ -509,7 +488,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do loc ) (Info (AppRes rt' retext)) -checkExp (Project k e NoInfo loc) = do +checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' kt <- mustHaveField (mkUsage loc $ docText $ "projection of field " <> dquotes (pretty k)) k t @@ -543,7 +522,7 @@ checkExp (QualParens (modname, modnameloc) e loc) = do ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." -checkExp (Var qn NoInfo loc) = do +checkExp (Var qn _ loc) = do t <- lookupVar loc qn pure $ Var qn (Info t) loc checkExp (Negate arg loc) = do @@ -552,7 +531,7 @@ checkExp (Negate arg loc) = do checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc -checkExp (AppExp (Apply fe args loc) NoInfo) = do +checkExp (AppExp (Apply fe args loc) _) = do fe' <- checkExp fe args' <- mapM (checkExp . snd) args t <- expType fe' @@ -598,7 +577,7 @@ checkExp (AppExp (LetPat sizes pat e body loc) _) = do AppExp (LetPat sizes (fmap toStruct pat') e' body' loc) (Info $ AppRes body_t' retext) -checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do +checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, _, e) body loc) _) = do (tparams', params', maybe_retdecl', rettype, e') <- checkBinding (name, maybe_retdecl, tparams, params, e, loc) @@ -621,19 +600,18 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do - src' <- checkIdent src slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage src "type of source array") "src" $ sliceDims slice' - unify (mkUsage loc "type of target array") t $ unInfo $ identType src' + unify (mkUsage loc "type of target array") t $ unInfo $ identType src (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t ve' <- unifies "type of target array" elemt =<< checkExp ve - bindingIdent dest (unInfo (identType src')) $ \dest' -> do + bindingIdent dest $ do body' <- checkExp body - (body_t, ext) <- unscopeType loc [identName dest'] =<< expTypeFully body' - pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t ext) + (body_t, ext) <- unscopeType loc [identName dest] =<< expTypeFully body' + pure $ AppExp (LetWith dest src slice' ve' body' loc) (Info $ AppRes body_t ext) checkExp (Update src slice ve loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' src) "src" $ sliceDims slice' @@ -645,11 +623,9 @@ checkExp (Update src slice ve loc) = do -- Record updates are a bit hacky, because we do not have row typing -- (yet?). For now, we only permit record updates where we know the -- full type up to the field we are updating. -checkExp (RecordUpdate src fields ve NoInfo loc) = do +checkExp (RecordUpdate src fields ve _ loc) = do src' <- checkExp src ve' <- checkExp ve - a <- expTypeFully src' - foldM_ (flip $ mustHaveField usage) a fields ve_t <- expType ve' updated_t <- updateField fields ve_t =<< expTypeFully src' pure $ RecordUpdate src' fields ve' (Info updated_t) loc @@ -681,11 +657,11 @@ checkExp (AppExp (Index e slice loc) _) = do =<< expTypeFully e' pure $ AppExp (Index e' slice' loc) (Info $ AppRes t' retext) -checkExp (Assert e1 e2 NoInfo loc) = do +checkExp (Assert e1 e2 _ loc) = do e1' <- require "being asserted" [Bool] =<< checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc -checkExp (Lambda params body rettype_te NoInfo loc) = do +checkExp (Lambda params body rettype_te _ loc) = do (params', body', rettype', RetType dims ty) <- incLevel . bindingParams [] params $ \params' -> do rettype_checked <- traverse checkTypeExpNonrigid rettype_te @@ -757,7 +733,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (OpSectionRight op _ e _ NoInfo loc) = do +checkExp (OpSectionRight op _ e _ _ loc) = do ftype <- lookupVar loc op e' <- checkExp e case ftype of @@ -782,13 +758,13 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (ProjectSection fields NoInfo loc) = do +checkExp (ProjectSection fields _ loc) = do a <- newTypeVar loc "a" let usage = mkUsage loc "projection at" b <- foldM (flip $ mustHaveField usage) a fields let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ toRes Nonunique b pure $ ProjectSection fields (Info ft) loc -checkExp (IndexSection slice NoInfo loc) = do +checkExp (IndexSection slice _ loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' (t', retext) <- sliceShape Nothing slice' t @@ -801,11 +777,9 @@ checkExp (AppExp (Loop _ mergepat mergeexp form loopbody loc) _) = do AppExp (Loop sparams mergepat' mergeexp' form' loopbody' loc) (Info appres) -checkExp (Constr name es NoInfo loc) = do +checkExp (Constr name es _ loc) = do t <- newTypeVar loc "t" es' <- mapM checkExp es - ets <- mapM expType es' - mustHaveConstr (mkUsage loc "use of constructor") name t ets pure $ Constr name es' (Info t) loc checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e @@ -819,9 +793,20 @@ checkExp (AppExp (Match e cs loc) _) = do checkExp (Attr info e loc) = Attr <$> checkAttr info <*> checkExp e <*> pure loc +checkCase :: + StructType -> + CaseBase Info VName -> + TermTypeM (CaseBase Info VName, StructType, [VName]) +checkCase mt (CasePat p e loc) = + bindingPat [] p mt $ \p' -> do + e' <- checkExp e + e_t <- expTypeFully e' + (e_t', retext) <- unscopeType loc (patNames p') e_t + pure (CasePat (fmap toStruct p') e' loc, e_t', retext) + checkCases :: StructType -> - NE.NonEmpty (CaseBase NoInfo VName) -> + NE.NonEmpty (CaseBase Info VName) -> TermTypeM (NE.NonEmpty (CaseBase Info VName), StructType, [VName]) checkCases mt rest_cs = case NE.uncons rest_cs of @@ -834,17 +819,6 @@ checkCases mt rest_cs = (brancht, retext) <- unifyBranchTypes (srclocOf c) c_t cs_t pure (NE.cons c' cs', brancht, retext) -checkCase :: - StructType -> - CaseBase NoInfo VName -> - TermTypeM (CaseBase Info VName, StructType, [VName]) -checkCase mt (CasePat p e loc) = - bindingPat [] p mt $ \p' -> do - e' <- checkExp e - e_t <- expTypeFully e' - (e_t', retext) <- unscopeType loc (patNames p') e_t - pure (CasePat (fmap toStruct p') e' loc, e_t', retext) - -- | An unmatched pattern. Used in in the generation of -- unmatched pattern warnings by the type checker. data Unmatched p @@ -873,12 +847,7 @@ instance Pretty (Unmatched (Pat StructType)) where pretty' (PatLit e _ _) = pretty e pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) -checkIdent :: IdentBase NoInfo VName StructType -> TermTypeM (Ident StructType) -checkIdent (Ident name _ loc) = do - vt <- lookupVar loc $ qualName name - pure $ Ident name (Info vt) loc - -checkSlice :: SliceBase NoInfo VName -> TermTypeM [DimIndex] +checkSlice :: SliceBase Info VName -> TermTypeM [DimIndex] checkSlice = mapM checkDimIndex where checkDimIndex (DimFix i) = do @@ -1039,8 +1008,8 @@ checkApply loc (fname, prev_applied) ftype argexp = do -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM checkExp $ do - e' <- checkExp e +checkOneExp e = runTermTypeM (checkExp . undefined) $ do + e' <- checkExp $ undefined e let t = typeOf e' (tparams, _, _) <- letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t @@ -1053,8 +1022,8 @@ checkOneExp e = runTermTypeM checkExp $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM checkExp $ do - e' <- checkExp e +checkSizeExp e = runTermTypeM (checkExp . undefined) $ do + e' <- checkExp $ undefined e let t = typeOf e' when (hasBinding e') $ typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $ @@ -1286,61 +1255,6 @@ localChecks = void . check <> pretty ty <> "." --- | Type-check a top-level (or module-level) function definition. --- Despite the name, this is also used for checking constant --- definitions, by treating them as 0-ary functions. -checkFunDef :: - ( VName, - Maybe (TypeExp (ExpBase NoInfo VName) VName), - [TypeParam], - [PatBase NoInfo VName ParamType], - ExpBase NoInfo VName, - SrcLoc - ) -> - TypeM - ( [TypeParam], - [Pat ParamType], - Maybe (TypeExp Exp VName), - ResRetType, - Exp - ) -checkFunDef (fname, maybe_retdecl, tparams, params, body, loc) = - runTermTypeM checkExp $ do - (tparams', params', maybe_retdecl', RetType dims rettype', body') <- - checkBinding (fname, maybe_retdecl, tparams, params, body, loc) - - -- Since this is a top-level function, we also resolve overloaded - -- types, using either defaults or complaining about ambiguities. - fixOverloadedTypes $ - typeVars rettype' <> foldMap (typeVars . patternType) params' - - -- Then replace all inferred types in the body and parameters. - body'' <- updateTypes body' - params'' <- updateTypes params' - maybe_retdecl'' <- traverse updateTypes maybe_retdecl' - rettype'' <- normTypeFully rettype' - - -- Check if the function body can actually be evaluated. - causalityCheck body'' - - -- Check for various problems. - mapM_ (mustBeIrrefutable . fmap toStruct) params' - localChecks body'' - - let ((body''', updated_ret), errors) = - Consumption.checkValDef - ( fname, - params'', - body'', - RetType dims rettype'', - maybe_retdecl'', - loc - ) - - mapM_ throwError errors - - pure (tparams', params'', maybe_retdecl'', updated_ret, body''') - -- | This is "fixing" as in "setting them", not "correcting them". We -- only make very conservative fixing. fixOverloadedTypes :: Names -> TermTypeM () @@ -1418,10 +1332,10 @@ inferredReturnType loc params t = do checkBinding :: ( VName, - Maybe (TypeExp (ExpBase NoInfo VName) VName), + Maybe (TypeExp Exp VName), [TypeParam], - [PatBase NoInfo VName ParamType], - ExpBase NoInfo VName, + [PatBase Info VName ParamType], + ExpBase Info VName, SrcLoc ) -> TermTypeM @@ -1670,7 +1584,7 @@ letGeneralise defname defloc tparams params restype = checkFunBody :: [Pat ParamType] -> - ExpBase NoInfo VName -> + Exp -> Maybe ResType -> SrcLoc -> TermTypeM Exp @@ -1705,3 +1619,73 @@ arrayOfM :: arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t + +addInitialConstraints :: M.Map VName (TypeBase () NoUniqueness) -> TermTypeM () +addInitialConstraints = mapM_ f . M.toList + where + addConstraint v c = modifyConstraints $ M.insert v (0, c) + usage = mkUsage (mempty :: Loc) "trust me bro" + f (v, t) = do + (t', _) <- allDimsFreshInType usage Nonrigid "dv" t + addConstraint v $ Constraint (RetType [] t') usage + +-- | Type-check a top-level (or module-level) function definition. +-- Despite the name, this is also used for checking constant +-- definitions, by treating them as 0-ary functions. +checkFunDef :: + ( VName, + Maybe (TypeExp (ExpBase NoInfo VName) VName), + [TypeParam], + [PatBase NoInfo VName ParamType], + ExpBase NoInfo VName, + SrcLoc + ) -> + TypeM + ( [TypeParam], + [Pat ParamType], + Maybe (TypeExp Exp VName), + ResRetType, + Exp + ) +checkFunDef (fname, retdecl, tparams, params, body, loc) = do + (maybe_tysubsts, params', retdecl', body') <- + Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + case maybe_tysubsts of + Left err -> typeError loc mempty $ pretty err + Right tysubsts -> runTermTypeM (checkExp . undefined) $ do + addInitialConstraints tysubsts + + (tparams', params'', retdecl'', RetType dims rettype', body'') <- + checkBinding (fname, retdecl', tparams, params', body', loc) + + -- Since this is a top-level function, we also resolve overloaded + -- types, using either defaults or complaining about ambiguities. + fixOverloadedTypes $ + typeVars rettype' <> foldMap (typeVars . patternType) params'' + + -- Then replace all inferred types in the body and parameters. + body''' <- updateTypes body'' + params''' <- updateTypes params'' + retdecl''' <- traverse updateTypes retdecl'' + rettype'' <- normTypeFully rettype' + + -- Check if the function body can actually be evaluated. + causalityCheck body''' + + -- Check for various problems. + mapM_ (mustBeIrrefutable . fmap toStruct) params'' + localChecks body''' + + let ((body'''', updated_ret), errors) = + Consumption.checkValDef + ( fname, + params''', + body''', + RetType dims rettype'', + retdecl''', + loc + ) + + mapM_ throwError errors + + pure (tparams', params''', retdecl''', updated_ret, body'''') diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index c5e3619ac7..349b105823 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -102,7 +102,7 @@ wellTypedLoopArg src sparams pat arg = do -- | An un-checked loop. type UncheckedLoop = - (PatBase NoInfo VName ParamType, ExpBase NoInfo VName, LoopFormBase NoInfo VName, ExpBase NoInfo VName) + (Pat ParamType, Exp, LoopFormBase Info VName, Exp) -- | A loop that has been type-checked. type CheckedLoop = @@ -111,7 +111,7 @@ type CheckedLoop = -- | Type-check a @loop@ expression, passing in a function for -- type-checking subexpressions. checkLoop :: - (ExpBase NoInfo VName -> TermTypeM Exp) -> + (Exp -> TermTypeM Exp) -> UncheckedLoop -> SrcLoc -> TermTypeM (CheckedLoop, AppRes) @@ -223,18 +223,16 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do uboundexp' <- require "being the bound in a 'for' loop" anySignedType =<< checkExp uboundexp - bound_t <- expTypeFully uboundexp' - bindingIdent i bound_t $ \i' -> - bindingPat [] mergepat merge_t $ - \mergepat' -> incLevel $ do - loopbody' <- checkExp loopbody - (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' - pure - ( sparams, - mergepat'', - For i' uboundexp', - loopbody' - ) + bindingIdent i . bindingPat [] mergepat merge_t $ + \mergepat' -> incLevel $ do + loopbody' <- checkExp loopbody + (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' + pure + ( sparams, + mergepat'', + For i uboundexp', + loopbody' + ) ForIn xpat e -> do (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1 e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index efa052fc7d..de3b03472d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -94,7 +94,7 @@ data Checking | CheckingAscription StructType StructType | CheckingLetGeneralise Name | CheckingParams (Maybe Name) - | CheckingPat (PatBase NoInfo VName StructType) (Inferred StructType) + | CheckingPat (PatBase Info VName StructType) (Inferred StructType) | CheckingLoopBody StructType StructType | CheckingLoopInitial StructType StructType | CheckingRecordUpdate [Name] StructType StructType @@ -544,8 +544,8 @@ allDimsFreshInType :: Usage -> Rigidity -> Name -> - TypeBase Size als -> - TermTypeM (TypeBase Size als, M.Map VName Size) + TypeBase d als -> + TermTypeM (TypeBase Size als, M.Map VName d) allDimsFreshInType usage r desc t = runStateT (bitraverse onDim pure t) mempty where @@ -581,25 +581,15 @@ require why ts e = do mustBeOneOf ts (mkUsage (srclocOf e) why) . toStruct =<< expType e pure e -termCheckTypeExp :: - TypeExp (ExpBase NoInfo VName) VName -> - TermTypeM (TypeExp Exp VName, [VName], ResRetType) -termCheckTypeExp te = do - (te', svars, rettype, _l) <- checkTypeExp te +checkTypeExpNonrigid :: TypeExp Exp VName -> TermTypeM (TypeExp Exp VName, ResType, [VName]) +checkTypeExpNonrigid te = do + (te', svars, rettype, _l) <- checkTypeExp $ undefined te -- No guarantee that the locally bound sizes in rettype are globally -- unique, but we want to turn them into size variables, so let's - -- give them some unique names. Maybe this should be done below, - -- where we actually turn these into size variables? + -- give them some unique names. RetType dims st <- renameRetType rettype - pure (te', svars, RetType dims st) - -checkTypeExpNonrigid :: - TypeExp (ExpBase NoInfo VName) VName -> - TermTypeM (TypeExp Exp VName, ResType, [VName]) -checkTypeExpNonrigid te = do - (te', svars, RetType dims st) <- termCheckTypeExp te forM_ (svars ++ dims) $ \v -> constrain v $ Size Nothing $ mkUsage (srclocOf te) "anonymous size in type expression" pure (te', st, svars ++ dims) diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 980f278326..ad4ea0aa0f 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -110,47 +110,21 @@ bindingSizes sizes m = binding (map sizeWithType sizes) m Ident (sizeName size) (Info (Scalar (Prim (Signed Int64)))) (srclocOf size) -- | Bind a single term-level identifier. -bindingIdent :: - IdentBase NoInfo VName StructType -> - StructType -> - (Ident StructType -> TermTypeM a) -> - TermTypeM a -bindingIdent (Ident v NoInfo vloc) t m = do - let ident = Ident v (Info t) vloc - binding [ident] $ m ident - --- All this complexity is just so we can handle un-suffixed numeric --- literals in patterns. -patLitMkType :: PatLit -> SrcLoc -> TermTypeM ParamType -patLitMkType (PatLitInt _) loc = do - t <- newTypeVar loc "t" - mustBeOneOf anyNumberType (mkUsage loc "integer literal") (toStruct t) - pure t -patLitMkType (PatLitFloat _) loc = do - t <- newTypeVar loc "t" - mustBeOneOf anyFloatType (mkUsage loc "float literal") (toStruct t) - pure t -patLitMkType (PatLitPrim v) _ = - pure $ Scalar $ Prim $ primValueType v +bindingIdent :: Ident StructType -> TermTypeM a -> TermTypeM a +bindingIdent ident = binding [ident] checkPat' :: [(SizeBinder VName, QualName VName)] -> - PatBase NoInfo VName ParamType -> + Pat ParamType -> Inferred ParamType -> TermTypeM (Pat ParamType) checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc -checkPat' _ (Id name NoInfo loc) (Ascribed t) = - pure $ Id name (Info t) loc -checkPat' _ (Id name NoInfo loc) NoneInferred = do - t <- newTypeVar loc "t" +checkPat' _ (Id name (Info t) loc) _ = pure $ Id name (Info t) loc -checkPat' _ (Wildcard _ loc) (Ascribed t) = - pure $ Wildcard (Info t) loc -checkPat' _ (Wildcard NoInfo loc) NoneInferred = do - t <- newTypeVar loc "t" +checkPat' _ (Wildcard (Info t) loc) _ = pure $ Wildcard (Info t) loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, @@ -192,6 +166,9 @@ checkPat' sizes (RecordPat fs loc) NoneInferred = RecordPat . M.toList <$> traverse (\p -> checkPat' sizes p NoneInferred) (M.fromList fs) <*> pure loc +checkPat' sizes (PatAscription p t loc) _ = + -- FIXME + PatAscription <$> checkPat' sizes p NoneInferred <*> pure t <*> pure loc checkPat' sizes (PatAscription p t loc) maybe_outer_t = do (t', st, _) <- checkTypeExpNonrigid t @@ -209,47 +186,15 @@ checkPat' sizes (PatAscription p t loc) maybe_outer_t = do <$> checkPat' sizes p (Ascribed (resToParam st)) <*> pure t' <*> pure loc -checkPat' _ (PatLit l NoInfo loc) (Ascribed t) = do - t' <- patLitMkType l loc - unify (mkUsage loc "matching against literal") (toStruct t') (toStruct t) - pure $ PatLit l (Info t') loc -checkPat' _ (PatLit l NoInfo loc) NoneInferred = do - t' <- patLitMkType l loc - pure $ PatLit l (Info t') loc -checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) - | Just ts <- M.lookup n cs = do - when (length ps /= length ts) $ - typeError loc mempty $ - "Pattern #" - <> pretty n - <> " expects" - <+> pretty (length ps) - <+> "constructor arguments, but type provides" - <+> pretty (length ts) - <+> "arguments." - ps' <- zipWithM (checkPat' sizes) ps $ map Ascribed ts - pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc -checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed t) = do - t' <- newTypeVar loc "t" - ps' <- forM ps $ \p -> do - p_t <- newTypeVar (srclocOf p) "t" - checkPat' sizes p $ Ascribed p_t - mustHaveConstr usage n (toStruct t') (patternStructType <$> ps') - unify usage t' (toStruct t) - pure $ PatConstr n (Info t) ps' loc - where - usage = mkUsage loc "matching against constructor" -checkPat' sizes (PatConstr n NoInfo ps loc) NoneInferred = do +checkPat' _ (PatLit l info loc) _ = + pure $ PatLit l info loc +checkPat' sizes (PatConstr n info ps loc) _ = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps - t <- newTypeVar loc "t" - mustHaveConstr usage n (toStruct t) (patternStructType <$> ps') - pure $ PatConstr n (Info t) ps' loc - where - usage = mkUsage loc "matching against constructor" + pure $ PatConstr n info ps' loc checkPat :: [(SizeBinder VName, QualName VName)] -> - PatBase NoInfo VName (TypeBase Size u) -> + Pat (TypeBase Size u) -> Inferred StructType -> (Pat ParamType -> TermTypeM a) -> TermTypeM a @@ -272,7 +217,7 @@ checkPat sizes p t m = do -- | Check and bind a @let@-pattern. bindingPat :: [SizeBinder VName] -> - PatBase NoInfo VName (TypeBase Size u) -> + Pat (TypeBase Size u) -> StructType -> (Pat ParamType -> TermTypeM a) -> TermTypeM a @@ -292,7 +237,7 @@ bindingPat sizes p t m = do -- | Check and bind type and value parameters. bindingParams :: [TypeParam] -> - [PatBase NoInfo VName ParamType] -> + [Pat ParamType] -> ([Pat ParamType] -> TermTypeM a) -> TermTypeM a bindingParams tps orig_ps m = bindingTypeParams tps $ do diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index ee62ed478f..3258e5db7b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -665,7 +665,10 @@ instance Pretty (Unmatched (Pat StructType)) where pretty' (PatLit e _ _) = pretty e pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps) -checkRetDecl :: Exp -> Maybe (TypeExp NoInfo VName) -> TermM (Maybe (TypeExp Info VName)) +checkRetDecl :: + Exp -> + Maybe (TypeExp (ExpBase NoInfo VName) VName) -> + TermM (Maybe (TypeExp Exp VName)) checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp te @@ -1007,16 +1010,16 @@ checkExp (Coerce e te NoInfo loc) = do checkValDef :: ( VName, - Maybe (TypeExp NoInfo VName), + Maybe (TypeExp (ExpBase NoInfo VName) VName), [TypeParam], [PatBase NoInfo VName ParamType], ExpBase NoInfo VName, SrcLoc ) -> TypeM - ( [TypeParam], + ( Either T.Text (M.Map TyVar Type), [Pat ParamType], - Maybe (TypeExp Info VName), + Maybe (TypeExp Exp VName), Exp ) checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do @@ -1040,4 +1043,4 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## solution:", either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) solution ] - pure (undefined, params', retdecl', body') + pure (solution, params', retdecl', body') From 8639b4a66fa1e6de507e6bfbc0bb239e093e3e06 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 9 Feb 2024 18:53:36 +0100 Subject: [PATCH 012/296] No more undefined. --- src/Language/Futhark/TypeChecker/Terms.hs | 6 +++--- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 12 ++++++------ src/Language/Futhark/TypeChecker/Terms/Pat.hs | 3 --- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 3e6ecbff19..c27ec06f5f 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1008,7 +1008,7 @@ checkApply loc (fname, prev_applied) ftype argexp = do -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM (checkExp . undefined) $ do +checkOneExp e = runTermTypeM checkExp $ do e' <- checkExp $ undefined e let t = typeOf e' (tparams, _, _) <- @@ -1022,7 +1022,7 @@ checkOneExp e = runTermTypeM (checkExp . undefined) $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM (checkExp . undefined) $ do +checkSizeExp e = runTermTypeM checkExp $ do e' <- checkExp $ undefined e let t = typeOf e' when (hasBinding e') $ @@ -1652,7 +1652,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM (checkExp . undefined) $ do + Right tysubsts -> runTermTypeM checkExp $ do addInitialConstraints tysubsts (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 1b06cdcdd7..49c5226d29 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -196,7 +196,7 @@ data TermEnv = TermEnv { termScope :: TermScope, termChecking :: Maybe Checking, termLevel :: Level, - termChecker :: ExpBase NoInfo VName -> TermTypeM Exp, + termCheckExp :: ExpBase Info VName -> TermTypeM Exp, termOuterEnv :: Env, termImportName :: ImportName } @@ -574,9 +574,9 @@ require why ts e = do mustBeOneOf ts (mkUsage (srclocOf e) why) . toStruct =<< expType e pure e -checkExpForSize :: ExpBase NoInfo VName -> TermTypeM Exp +checkExpForSize :: ExpBase Info VName -> TermTypeM Exp checkExpForSize e = do - checker <- asks termChecker + checker <- asks termCheckExp e' <- checker e let t = toStruct $ typeOf e' unify (mkUsage (locOf e') "Size expression") t (Scalar (Prim (Signed Int64))) @@ -584,7 +584,7 @@ checkExpForSize e = do checkTypeExpNonrigid :: TypeExp Exp VName -> TermTypeM (TypeExp Exp VName, ResType, [VName]) checkTypeExpNonrigid te = do - (te', svars, rettype, _l) <- checkTypeExp checkExpForSize $ undefined te + (te', svars, rettype, _l) <- checkTypeExp checkExpForSize te -- No guarantee that the locally bound sizes in rettype are globally -- unique, but we want to turn them into size variables, so let's @@ -636,7 +636,7 @@ initialTermScope = Just (name, EqualityF) addIntrinsicF _ = Nothing -runTermTypeM :: (ExpBase NoInfo VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a +runTermTypeM :: (ExpBase Info VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a runTermTypeM checker (TermTypeM m) = do initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv name <- askImportName @@ -647,7 +647,7 @@ runTermTypeM checker (TermTypeM m) = do { termScope = initial_scope, termChecking = Nothing, termLevel = 0, - termChecker = checker, + termCheckExp = checker, termImportName = name, termOuterEnv = outer_env } diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index ad4ea0aa0f..16ad00f710 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -166,9 +166,6 @@ checkPat' sizes (RecordPat fs loc) NoneInferred = RecordPat . M.toList <$> traverse (\p -> checkPat' sizes p NoneInferred) (M.fromList fs) <*> pure loc -checkPat' sizes (PatAscription p t loc) _ = - -- FIXME - PatAscription <$> checkPat' sizes p NoneInferred <*> pure t <*> pure loc checkPat' sizes (PatAscription p t loc) maybe_outer_t = do (t', st, _) <- checkTypeExpNonrigid t From 3f09449594aca2f077b59a151357cde64f1558f1 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 9 Feb 2024 16:05:57 -0800 Subject: [PATCH 013/296] Add ILP/LP solving stuff. --- futhark.cabal | 4 + src/Futhark/Solve/BranchAndBound.hs | 75 ++++++ src/Futhark/Solve/LP.hs | 306 +++++++++++++++++++++++++ src/Futhark/Solve/Matrix.hs | 341 ++++++++++++++++++++++++++++ src/Futhark/Solve/Simplex.hs | 238 +++++++++++++++++++ 5 files changed, 964 insertions(+) create mode 100644 src/Futhark/Solve/BranchAndBound.hs create mode 100644 src/Futhark/Solve/LP.hs create mode 100644 src/Futhark/Solve/Matrix.hs create mode 100644 src/Futhark/Solve/Simplex.hs diff --git a/futhark.cabal b/futhark.cabal index c89eacb214..1c0531a83b 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -370,6 +370,10 @@ library Futhark.Pkg.Types Futhark.Profile Futhark.Script + Futhark.Solve.LP + Futhark.Solve.Matrix + Futhark.Solve.Simplex + Futhark.Solve.BranchAndBound Futhark.Test Futhark.Test.Spec Futhark.Test.Values diff --git a/src/Futhark/Solve/BranchAndBound.hs b/src/Futhark/Solve/BranchAndBound.hs new file mode 100644 index 0000000000..846ae4a59a --- /dev/null +++ b/src/Futhark/Solve/BranchAndBound.hs @@ -0,0 +1,75 @@ +module Futhark.Solve.BranchAndBound (branchAndBound) where + +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.LP (LP (..)) +import Futhark.Solve.Matrix +import Futhark.Solve.Simplex + +newtype Bound a = Bound (Maybe a, Maybe a) + deriving (Eq, Ord, Show) + +instance (Ord a) => Semigroup (Bound a) where + Bound (mlb1, mub1) <> Bound (mlb2, mub2) = + Bound (combine max mlb1 mlb2, combine min mub1 mub2) + where + combine _ Nothing b2 = b2 + combine _ b1 Nothing = b1 + combine c (Just b1) (Just b2) = Just $ c b1 b2 + +-- | Solves an LP with the additional constraint that all solutions +-- must be integral. Returns 'Nothing' if infeasible or unbounded. +branchAndBound :: + (Read a, Unbox a, RealFrac a, Show a) => + LP a -> + Maybe (a, Vector Int) +branchAndBound prob@(LP _ a d) = (zopt,) <$> mopt + where + (zopt, mopt) = step (S.singleton mempty) (negate $ read "Infinity") Nothing + step todo zlow opt + | S.null todo = (zlow, opt) + | otherwise = + let (next, rest) = S.deleteFindMin todo + in case simplexLP (mkProblem next) of + Nothing -> step rest zlow opt + Just (z, sol) + | z <= zlow -> step rest zlow opt + | V.all isInt sol -> + step rest z (Just $ V.map round sol) + | otherwise -> + let (idx, frac) = + V.head $ V.filter (not . isInt . snd) $ V.zip (V.generate (V.length sol) id) sol + new_todo = + S.fromList $ + filter + (/= next) + [ M.insertWith (<>) idx (Bound (Nothing, Just $ fromInteger $ floor frac)) next, + M.insertWith (<>) idx (Bound (Just $ fromInteger $ ceiling frac, Nothing)) next + ] + in step (new_todo <> rest) zlow opt + + -- TODO: use isInt x = x == round x + -- requires a better 'rowEchelon' implementation for matrices + isInt x = (abs (fromIntegral (round x) - x)) <= 10 ^^ (-10) + mkProblem = + M.foldrWithKey + ( \idx bound acc -> addBound acc idx bound + ) + prob + + addBound lp idx (Bound (mlb, mub)) = + lp + { lpA = a `addRows` new_rows, + lpd = d V.++ V.fromList new_ds + } + where + (new_rows, new_ds) = + unzip $ + catMaybes + [ (V.generate (ncols a) (\i -> if i == idx then (-1) else 0),) <$> (negate <$> mlb), + (V.generate (ncols a) (\i -> if i == idx then 1 else 0),) <$> mub + ] diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs new file mode 100644 index 0000000000..11e943a1b1 --- /dev/null +++ b/src/Futhark/Solve/LP.hs @@ -0,0 +1,306 @@ +module Futhark.Solve.LP + ( LP (..), + LPE (..), + convert, + normalize, + var, + constant, + cval, + bin, + or, + oneIsZero, + (~+~), + (~-~), + (~*~), + (!), + neg, + linearProgToLP, + linearProgToLPE, + LSum (..), + LinearProg (..), + OptType (..), + Constraint (..), + (==), + (<=), + (>=), + rowEchelonLPE, + ) +where + +import Data.List qualified as L +import Data.Map (Map) +import Data.Map qualified as Map +import Data.Maybe +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.Matrix (Matrix (..)) +import Futhark.Solve.Matrix qualified as M +import Prelude hiding (or, (<=), (==), (>=)) +import Prelude qualified + +-- | A linear program. 'LP c a d' represents the program +-- +-- > maximize c^T * a +-- > subject to a * x <= d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LP a = LP + { lpc :: Vector a, + lpA :: Matrix a, + lpd :: Vector a + } + deriving (Eq, Show) + +-- | Equational form of a linear program. 'LPE c a d' represents the +-- program +-- +-- > maximize c^T * a +-- > subject to a * x = d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LPE a = LPE + { pc :: Vector a, + pA :: Matrix a, + pd :: Vector a + } + deriving (Eq, Show) + +rowEchelonLPE :: (Show a, Unbox a, Fractional a, Ord a) => LPE a -> LPE a +rowEchelonLPE (LPE c a d) = + LPE c (M.sliceCols (V.generate (ncols a) id) ad) (M.getCol (ncols a) ad) + where + ad = + M.filterRows (V.any (Prelude./= 0)) $ + (M.rowEchelon $ a M.<|> M.fromColVector d) + +-- | Converts an 'LP' into an equivalent 'LPE' by introducing slack +-- variables. +convert :: (Show a, Num a, Unbox a) => LP a -> LPE a +convert (LP c a d) = LPE c' a' d + where + a' = a M.<|> M.diagonal (V.replicate (M.nrows a) 1) + c' = c V.++ V.replicate (M.nrows a) 0 + +-- | Linear sum of variables. +newtype LSum v a = LSum {lsum :: (Map (Maybe v) a)} + deriving (Eq) + +instance (Show v, Show a) => Show (LSum v a) where + show (LSum m) = + L.intercalate + " + " + $ map + ( \(k, a) -> + case k of + Nothing -> show a + Just k' -> show a <> "*" <> show k' + ) + $ Map.toList m + +instance Functor (LSum v) where + fmap f (LSum m) = LSum $ fmap f m + +-- | Type of constraint +data CType = Equal | LessEq + deriving (Eq) + +instance Show CType where + show (Equal) = "=" + show (LessEq) = "<=" + +-- | A constraint for a linear program. +data Constraint v a + = Constraint CType (LSum v a) (LSum v a) + deriving (Eq) + +instance (Show a, Show v) => Show (Constraint v a) where + show (Constraint t l r) = + show l <> " " <> show t <> " " <> show r + +data OptType = Maximize | Minimize + deriving (Show, Eq) + +-- | A linear program. +data LinearProg v a = LinearProg + { optType :: OptType, + objective :: LSum v a, + constraints :: [Constraint v a] + } + deriving (Eq) + +instance (Show v, Show a) => Show (LinearProg v a) where + show (LinearProg opt obj cs) = + unlines $ + [ show opt, + show obj, + "subject to:" + ] + ++ map show cs + +bigM :: (Num a) => a +bigM = 10 ^ 3 + +oneIsZero :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +oneIsZero b1 b2 x1 x2 = + mkC b1 x1 + <> mkC b2 x2 + <> [(var b1 ~+~ var b2) <= constant 1] + where + mkC b x = + [ var x <= bigM ~*~ var b + ] + +or :: (Eq a, Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] +or b1 b2 c1 c2 = + mkC b1 c1 + <> mkC b2 c2 + <> [var b1 ~+~ var b2 <= constant 1] + where + mkC b (Constraint Equal l r) = + [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b), + l >= r ~-~ bigM ~*~ (constant 1 ~-~ var b) + ] + mkC b (Constraint LessEq l r) = + [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +bin :: (Num a, Ord v) => v -> Constraint v a +bin v = Constraint LessEq (var v) (constant 1) + +(==) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l == r = Constraint Equal l r + +infix 4 == + +(<=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l <= r = Constraint LessEq l r + +infix 4 <= + +(>=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l >= r = Constraint LessEq (neg l) (neg r) + +infix 4 >= + +normalize :: (Eq a, Num a) => LSum v a -> LSum v a +normalize = LSum . Map.filter (/= 0) . lsum + +var :: (Num a) => v -> LSum v a +var v = LSum $ Map.singleton (Just v) (fromInteger 1) + +constant :: a -> LSum v a +constant = LSum . Map.singleton Nothing + +cval :: (Num a, Ord v) => LSum v a -> a +cval = (! Nothing) + +(~+~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a +(LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y + +infixl 6 ~+~ + +(~-~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a +x ~-~ y = x ~+~ (neg y) + +infixl 6 ~-~ + +(~*~) :: (Eq a, Num a, Ord v) => a -> LSum v a -> LSum v a +a ~*~ s = normalize $ fmap (a *) s + +infixl 7 ~*~ + +(!) :: (Num a, Ord v) => LSum v a -> Maybe v -> a +(LSum m) ! v = + case m Map.!? v of + Nothing -> 0 + Just a -> a + +neg :: (Num a, Ord v) => LSum v a -> LSum v a +neg (LSum x) = LSum $ fmap negate x + +-- | Converts a linear program given with a list of constraints +-- into the standard form. +linearProgToLP :: + forall v a. + (Unbox a, Num a, Ord v, Eq a) => + LinearProg v a -> + (LP a, Map Int v) +linearProgToLP (LinearProg otype obj cs) = + (LP c a d, idxMap) + where + cs' = foldMap (convertEqCType . splitConstraint) cs + idxMap = + Map.fromList $ + zip [0 ..] $ + catMaybes $ + Map.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) + c = mkRow $ convertObj otype obj + a = M.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] + convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] + convertEqCType (LessEq, s, a) = [(s, a)] + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s + +-- | Converts a linear program given with a list of constraints +-- into the equational form. Assumes no <= constraints. +linearProgToLPE :: + forall v a. + (Unbox a, Num a, Ord v, Eq a) => + LinearProg v a -> + (LPE a, Map Int v) +linearProgToLPE (LinearProg otype obj cs) = + (LPE c a d, idxMap) + where + cs' = map (checkOnlyEqType . splitConstraint) cs + idxMap = + Map.fromList $ + zip [0 ..] $ + catMaybes $ + Map.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) + c = mkRow $ convertObj otype obj + a = M.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + checkOnlyEqType :: (CType, LSum v a, a) -> (LSum v a, a) + checkOnlyEqType (Equal, s, a) = (s, a) + checkOnlyEqType (ctype, _, _) = error $ show ctype + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s + +test1 :: LPE Double +test1 = + LPE + { pc = V.fromList [5.5, 2.1], + pA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + pd = V.fromList [2, 17] + } diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs new file mode 100644 index 0000000000..90e1a3e126 --- /dev/null +++ b/src/Futhark/Solve/Matrix.hs @@ -0,0 +1,341 @@ +module Futhark.Solve.Matrix + ( Matrix (..), + toList, + toLists, + fromRowVector, + fromColVector, + fromVectors, + fromLists, + (@), + (!), + sliceCols, + getColM, + getCol, + setCol, + sliceRows, + getRowM, + getRow, + (<|>), + (<->), + addRow, + addRows, + imap, + generate, + identity, + diagonal, + (<.>), + (.*), + (*.), + (.+.), + (.-.), + rowEchelon, + filterRows, + deleteRow, + deleteCol, + ) +where + +import Data.List qualified as L +import Data.Map (Map) +import Data.Map qualified as M +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V + +-- A matrix represented as a 1D 'Vector'. +data Matrix a = Matrix + { elems :: Vector a, + nrows :: Int, + ncols :: Int + } + deriving (Eq) + +instance (Show a, Unbox a) => Show (Matrix a) where + show = + unlines . map show . toLists + +toList :: (Unbox a) => Matrix a -> [Vector a] +toList m = + map (\r -> V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +toLists :: (Unbox a) => Matrix a -> [[a]] +toLists m = + map (\r -> V.toList $ V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +fromRowVector :: (Unbox a) => Vector a -> Matrix a +fromRowVector v = + Matrix + { elems = v, + nrows = 1, + ncols = V.length v + } + +fromColVector :: (Unbox a) => Vector a -> Matrix a +fromColVector v = + Matrix + { elems = v, + nrows = V.length v, + ncols = 1 + } + +empty :: (Unbox a) => Matrix a +empty = Matrix mempty 0 0 + +fromVectors :: (Unbox a) => [Vector a] -> Matrix a +fromVectors [] = empty +fromVectors vs = + Matrix + { elems = V.concat $ vs, + nrows = length vs, + ncols = V.length $ head vs + } + +fromLists :: (Unbox a) => [[a]] -> Matrix a +fromLists xss = + Matrix + { elems = V.concat $ map V.fromList xss, + nrows = length xss, + ncols = length $ head xss + } + +class SelectCols a where + select :: Vector Int -> a -> a + (@) :: a -> Vector Int -> a + (@) = flip select + +infix 9 @ + +instance (Unbox a) => SelectCols (Vector a) where + select s v = V.map (v V.!) s + +instance (Unbox a) => SelectCols (Matrix a) where + select = sliceCols + +(!) :: (Unbox a) => Matrix a -> (Int, Int) -> a +m ! (r, c) = elems m V.! (ncols m * r + c) + +sliceCols :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceCols cols m = + Matrix + { elems = + V.generate (nrows m * V.length cols) $ \i -> + let col = cols V.! (i `rem` V.length cols) + row = i `div` V.length cols + in m ! (row, col), + nrows = nrows m, + ncols = V.length cols + } + +getColM :: (Unbox a) => Int -> Matrix a -> Matrix a +getColM col = sliceCols $ V.singleton col + +getCol :: (Unbox a) => Int -> Matrix a -> Vector a +getCol col = elems . getColM col + +setCol :: (Unbox a) => Int -> Vector a -> Matrix a -> Matrix a +setCol c col m = + m + { elems = + V.update_ (elems m) indices col + } + where + indices = V.generate (nrows m) $ + \r -> r * ncols m + c + +sliceRows :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceRows rows m = + Matrix + { elems = + V.generate (ncols m * V.length rows) $ \i -> + let row = rows V.! (i `rem` V.length rows) + col = i `div` V.length rows + in m ! (row, col), + nrows = V.length rows, + ncols = ncols m + } + +getRowM :: (Unbox a) => Int -> Matrix a -> Matrix a +getRowM row = sliceRows $ V.singleton row + +getRow :: (Unbox a) => Int -> Matrix a -> Vector a +getRow row = elems . getRowM row + +(<|>) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <|> m2 = + generate f (nrows m1) (ncols m1 + ncols m2) + where + f r c + | c < ncols m1 = m1 ! (r, c) + | otherwise = m2 ! (r, c - ncols m1) + +(<->) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <-> m2 = + generate f (nrows m1 + nrows m2) (ncols m1) + where + f r c + | r < nrows m1 = m1 ! (r, c) + | otherwise = m2 ! (r - nrows m1, c) + +addRow :: (Unbox a) => Matrix a -> Vector a -> Matrix a +addRow m v = + m + { elems = elems m V.++ v, + nrows = nrows m + 1 + } + +addRows :: (Unbox a) => Matrix a -> [Vector a] -> Matrix a +addRows = foldl addRow + +imap :: (Unbox a) => (Int -> Int -> a -> a) -> Matrix a -> Matrix a +imap f m = + m + { elems = V.imap g $ elems m + } + where + g i = + let r = i `div` ncols m + c = i `rem` nrows m + in f r c + +generate :: (Unbox a) => (Int -> Int -> a) -> Int -> Int -> Matrix a +generate f rows cols = + Matrix + { elems = + V.generate (rows * cols) $ \i -> + let r = i `div` cols + c = i `rem` cols + in f r c, + nrows = rows, + ncols = cols + } + +identity :: (Unbox a, Num a) => Int -> Matrix a +identity n = generate (\r c -> if r == c then 1 else 0) n n + +diagonal :: (Unbox a, Num a) => Vector a -> Matrix a +diagonal d = generate (\r c -> if r == c then d V.! r else 0) (V.length d) (V.length d) + +(<.>) :: (Unbox a, Num a) => Vector a -> Vector a -> a +v1 <.> v2 = V.sum $ V.zipWith (*) v1 v2 + +infixl 7 <.> + +(*.) :: (Unbox a, Num a) => Matrix a -> Vector a -> Vector a +m *. v = + V.generate (nrows m) $ \r -> + getRow r m <.> v + +infixl 7 *. + +(.*) :: (Unbox a, Num a) => Vector a -> Matrix a -> Vector a +v .* m = + V.generate (ncols m) $ \c -> + v <.> getCol c m + +infixl 7 .* + +(.-.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.-.) = V.zipWith (-) + +infixl 6 .-. + +(.+.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.+.) = V.zipWith (+) + +infixl 6 .+. + +swapRows :: (Unbox a) => Int -> Int -> Matrix a -> Matrix a +swapRows r1 r2 m = + m + { elems = + elems m `V.update` new + } + where + start1 = ncols m * r1 + start2 = ncols m * r2 + row1 = getRow r1 m + row2 = getRow r2 m + new = + V.imap (\i a -> (i + start1, a)) row2 + V.++ V.imap (\i a -> (i + start2, a)) row1 + +-- todo: fix +update :: (Unbox a) => Matrix a -> Vector ((Int, Int), a) -> Matrix a +update m upds = + generate + ( \i j -> + case (M.fromList $ V.toList upds) M.!? (i, j) of + Nothing -> m ! (i, j) + Just x -> x + ) + (nrows m) + (ncols m) + +update_ :: (Unbox a) => Matrix a -> Map (Int, Int) a -> Matrix a +update_ m upds = + generate + ( \i j -> + case upds M.!? (i, j) of + Nothing -> m ! (i, j) + Just x -> x + ) + (nrows m) + (ncols m) + +-- TODO: maintain integrality of entries in the matrix +-- rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a +-- rowEchelon = rowEchelon' 0 0 +-- where +-- rowEchelon' h k m@(Matrix _ nr nc) +-- | h < nr && k < nc = +-- if m ! (pivot_row, k) == 0 +-- then rowEchelon' h (k + 1) m +-- else rowEchelon' (h + 1) (k + 1) clear_rows_below +-- | otherwise = m +-- where +-- pivot_row = +-- fst $ +-- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ +-- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] +-- m' = swapRows h pivot_row m +-- clear_rows_below = +-- update m' $ +-- V.fromList $ +-- [((i, k), 0) | i <- [h + 1 .. nr - 1]] +-- ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) +-- | i <- [h + 1 .. nr - 1], +-- let f = m' ! (i, k) / m' ! (h, k), +-- j <- [k + 1 .. nc - 1] +-- ] + +rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +rowEchelon = rowEchelon' 0 0 + where + rowEchelon' h k m@(Matrix _ nr nc) + | h < nr && k < nc = + if m ! (pivot_row, k) == 0 + then rowEchelon' h (k + 1) m + else rowEchelon' (h + 1) (k + 1) clear_rows_below + | otherwise = m + where + pivot_row = + fst $ + L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ + [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] + m' = swapRows h pivot_row m + clear_rows_below = + update m' $ + V.fromList $ + [((i, k), 0) | i <- [h + 1 .. nr - 1]] + ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) + | i <- [h + 1 .. nr - 1], + j <- [k + 1 .. nc - 1] + ] + +filterRows :: (Unbox a) => (Vector a -> Bool) -> Matrix a -> Matrix a +filterRows p = fromVectors . filter p . toList + +deleteRow :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteRow n m = sliceRows (V.generate (nrows m - 1) (\r -> if r < n then r else r + 1)) m + +deleteCol :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteCol n m = sliceCols (V.generate (ncols m - 1) (\c -> if c < n then c else c + 1)) m diff --git a/src/Futhark/Solve/Simplex.hs b/src/Futhark/Solve/Simplex.hs new file mode 100644 index 0000000000..e01c7ce566 --- /dev/null +++ b/src/Futhark/Solve/Simplex.hs @@ -0,0 +1,238 @@ +module Futhark.Solve.Simplex + ( simplex, + simplexLP, + simplexProg, + findBasis, + ) +where + +import Data.List qualified as L +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as M +import Data.Maybe +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.LP (LP (..), LPE (..), LinearProg (..), convert, linearProgToLPE, rowEchelonLPE) +import Futhark.Solve.Matrix + +-- | A tableau of an equational linear program @a * x = d@ is +-- +-- > x @ b = p + q * x @ n +-- > --------------------- +-- > z = z' + r^T * x @ n +-- +-- where @z = c^T * x@ and @b@ (@n@) is a vector containing the +-- indices of basic (nonbasic) variables. +-- +-- The basic feasible solution corresponding to the above tableau is +-- given by @x \@ b = p@, @x \@n = 0@ with the value of the objective +-- equal to @z'@. + +-- | Computes @r@ as given in the tableau above. +comp_r :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + Vector Int -> + Vector a +comp_r (LPE c a _) invA_B b n = + c @ n .-. c @ b .* invA_B .* a @ n + +-- | @comp_q_enter prob invA_B b n enter@ computes the @enter@th +-- column of @q@. +comp_q_enter :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Int -> + Vector a +comp_q_enter (LPE _ a _) invA_B enter = + V.map negate $ invA_B *. getCol enter a + +-- | Computes the objective given an inversion of @a@ and a basis. +comp_z :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + a +comp_z (LPE c _ d) invA_B b = + c @ b .* invA_B <.> d + +-- | Constructs an auxiliary equational linear program to compute the +-- initial feasible basis; returns the program along with a feasible +-- basis. +mkAux :: (Ord a, Unbox a, Num a) => LPE a -> (LPE a, Vector Int, Vector Int) +mkAux (LPE _ a d) = (LPE c_aux a_aux d_aux, b_aux, n_aux) + where + c_aux = V.replicate (ncols a) 0 V.++ V.replicate (nrows a) (-1) + d_aux = V.map abs d + a_aux = + imap (\r _ e -> if (d V.! r) < 0 then negate e else e) a + <|> identity (nrows a) + b_aux = V.generate (nrows a) (+ ncols a) + n_aux = V.generate (ncols a) id + +-- | Finds an initial feasible basis for an equational linear program. +-- Returns 'Nothing' if the LP has no solution. Inverts some +-- equations by multiplying by -1 so it also returns a modified (but +-- equivalent) equational linear program. +findBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +findBasis prob = do + (invA_B, p, b, n) <- step p_aux (invA_B_aux, d_aux, b_aux, n_aux) + if comp_z p_aux invA_B b == 0 + then Just $ fixDegenerateBasis prob (ncols $ pA prob) p_aux (invA_B, p, b, n) + else Nothing + where + (p_aux@(LPE _ _ d_aux), b_aux, n_aux) = mkAux prob + invA_B_aux = identity $ V.length b_aux + + fixDegenerateBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Int -> + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (LPE a, Matrix a, Vector a, Vector Int, Vector Int) + fixDegenerateBasis og_prob col prob (invA_B, p, b, n) + | Just exit_idx <- mexit_idx, + V.null (elim_row exit_idx) = + let prob' = + prob + { pA = deleteRow exit_idx (pA prob), + pd = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) $ + pd prob + } + invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B + p' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) p + b' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) b + in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) + | Just exit_idx <- mexit_idx, + (enter, _) <- V.head (elim_row exit_idx) = + let enter_idx = fromJust $ V.findIndex (== enter) n + exit = b V.! exit_idx + in fixDegenerateBasis og_prob col prob $ + pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = + let prob' = + prob + { pc = pc og_prob, + pA = sliceCols (V.generate col id) $ pA prob, + pd = V.map abs $ pd og_prob + } + in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) + where + mexit_idx = + fst <$> ((V.filter ((>= col) . snd) (V.imap (curry id) b)) V.!? 0) + elim_row exit_idx = + V.filter ((/= 0) . snd) $ + V.map (\j -> (j, comp_q_enter prob invA_B j V.! exit_idx)) $ + V.generate col id + +-- | Solves an equational linear program. Returns 'Nothing' if the +-- program is infeasible or unbounded. Otherwise returns the optimal +-- value and the solution. +simplex :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (a, Vector a) +simplex lpe = do + let ech_lpe = rowEchelonLPE lpe + res@(lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe + (invA_B', p', b', n') <- step lpe' (invA_B, p, b, n) + let z = comp_z lpe' invA_B' b' + sol = + V.map snd $ + V.fromList $ + L.sortOn fst $ + V.toList $ + V.zip (b' V.++ n') (p' V.++ V.replicate (V.length n') 0) + pure (z, sol) + +-- | Solves a linear program. +simplexLP :: + (Unbox a, Ord a, Fractional a, Show a) => + LP a -> + Maybe (a, Vector a) +simplexLP lp = do + (opt, sol) <- simplex lpe + pure (opt, V.take (ncols $ lpA lp) sol) + where + lpe = convert lp + +simplexProg :: + (Unbox a, Ord a, Ord v, Fractional a, Show a) => + LinearProg v a -> + Maybe (a, Map v a) +simplexProg prog = do + (z, sol) <- simplex lpe + pure $ (z, M.fromList $ map (\(i, x) -> (idxMap M.! i, x)) $ zip [0 ..] $ V.toList sol) + where + (lpe, idxMap) = linearProgToLPE prog + +pivot :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (Int, Int) -> + (Int, Int) -> + (Matrix a, Vector a, Vector Int, Vector Int) +pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) = + (invA_B', p', b', n') + where + q_enter = comp_q_enter prob invA_B enter + b' = b V.// [(exit_idx, enter)] + n' = n V.// [(enter_idx, exit)] + e_inv_vec = + V.map + (/ abs (q_enter V.! exit_idx)) + (q_enter V.// [(exit_idx, 1)]) + genF row col = + (if row == exit_idx then 0 else invA_B ! (row, col)) + + (e_inv_vec V.! row) * invA_B ! (exit_idx, col) + invA_B' = generate genF (nrows invA_B) (ncols invA_B) + p' = p V.// [(exit_idx, 0)] .+. V.map (* (p V.! exit_idx)) e_inv_vec + +-- | One step of the simplex algorithm. +step :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + Maybe (Matrix a, Vector a, Vector Int, Vector Int) +step prob (invA_B, p, b, n) + | Just enter_idx <- menter_idx = + let enter = n V.! enter_idx + q_enter = comp_q_enter prob invA_B enter + pq = + V.map (\(i, p_', q_) -> (i, -(p_' / q_))) $ + V.filter (\(_, _, q_) -> q_ < 0) $ + V.zip3 (V.generate (V.length q_enter) id) p q_enter + in if V.null pq + then Nothing + else + let exit_val = snd $ V.minimumOn snd pq + exit_cands = + V.map fst $ V.filter ((exit_val ==) . snd) pq + (exit_idx, exit) = + V.minimumOn snd $ + V.map (\i -> (i, b V.! i)) exit_cands + in step prob $ pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = Just (invA_B, p, b, n) + where + r = comp_r prob invA_B b n + menter_idx = V.findIndex (> 0) r + b_zero = V.filter (\(v, i) -> v == 0 && (not $ V.null (V.filter (< i) n))) $ V.zip p b From c98bbe406ace88272d1dd0e472901dc71cf52f00 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 9 Feb 2024 17:37:37 -0800 Subject: [PATCH 014/296] Add AM AST annotation. --- src/Futhark/Internalise/Defunctionalise.hs | 14 +++++----- src/Futhark/Internalise/Exps.hs | 2 +- src/Futhark/Internalise/FullNormalise.hs | 6 ++-- src/Futhark/Internalise/LiftLambdas.hs | 2 +- src/Futhark/Internalise/Monomorphise.hs | 16 +++++------ src/Language/Futhark/Interpreter.hs | 2 +- src/Language/Futhark/Syntax.hs | 28 ++++++++++++++++--- .../Futhark/TypeChecker/Consumption.hs | 5 ++-- src/Language/Futhark/TypeChecker/Terms.hs | 4 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 10 files changed, 51 insertions(+), 30 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 8ad2e15948..d6b03a368b 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -752,7 +752,7 @@ etaExpand e_t e = do M.fromList . zip (retDims ret) $ map (ExpSubst . flip sizeFromName mempty . qualName) ext' ret' = applySubst (`M.lookup` extsubst) ret - e' = mkApply e (map (Nothing,) vars) $ AppRes (toStruct $ retType ret') ext' + e' = mkApply e (map (\v -> (Nothing, mempty, v)) vars) $ AppRes (toStruct $ retType ret') ext' pure (params, e', ret) where getType (RetType _ (Scalar (Arrow _ p d t1 t2))) = @@ -910,9 +910,9 @@ liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> - ((Maybe VName, Exp), [ParamType]) -> + (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) -defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) ((argext, arg), _) = do +defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg let env' = alwaysMatchPatSV pat arg_sv dims = mempty @@ -963,18 +963,18 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) ((argext, ar callret <- unRetType lifted_rettype pure - ( mkApply fname' [(Nothing, f'), (argext, arg')] callret, + ( mkApply fname' [(Nothing, mempty, f'), (argext, mempty, arg')] callret, sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. -defuncApplyArg _ (f', DynamicFun _ sv) ((argext, arg), argtypes) = do +defuncApplyArg _ (f', DynamicFun _ sv) (((argext, _), arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] - apply_e = mkApply f' [(argext, arg')] callret + apply_e = mkApply f' [(argext, mempty, arg')] callret pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = @@ -991,7 +991,7 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -defuncApply :: Exp -> NE.NonEmpty (Maybe VName, Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index b2a8f37247..ec4adece0b 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1484,7 +1484,7 @@ findFuncall (E.Apply f args _) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info argext, e) = (e, argext) + onArg (Info (argext, _), e) = (e, argext) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index f797557776..f566f43e64 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -202,7 +202,7 @@ getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (I let y = Var (qualName yn) (Info $ toStruct yty) mempty ret' = applySubst (pSubst x y) ret body = - mkApply (Var op ty mempty) [(xext, x), (Nothing, y)] $ + mkApply (Var op ty mempty) [(xext, mempty, x), (Nothing, mempty, y)] $ AppRes (toStruct ret') exts nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc where @@ -215,7 +215,7 @@ getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext)) ( y <- getOrdering False e let x = Var (qualName xn) (Info $ toStruct xty) mempty ret' = applySubst (pSubst x y) ret - body = mkApply (Var op ty mempty) [(Nothing, x), (yext, y)] $ AppRes (toStruct ret') [] + body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, mempty, y)] $ AppRes (toStruct ret') [] nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc where pSubst x y vn @@ -304,7 +304,7 @@ getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) lo (False, False) -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er - pure $ mkApply (Var op opT oloc) [(elp, el'), (erp, er')] resT + pure $ mkApply (Var op opT oloc) [(elp, mempty, el'), (erp, mempty, er')] resT nameExp final expr' where isOr = baseName (qualLeaf op) == "||" diff --git a/src/Futhark/Internalise/LiftLambdas.hs b/src/Futhark/Internalise/LiftLambdas.hs index 0c9aead794..f515083170 100644 --- a/src/Futhark/Internalise/LiftLambdas.hs +++ b/src/Futhark/Internalise/LiftLambdas.hs @@ -138,7 +138,7 @@ liftFunction fname tparams params (RetType dims ret) funbody = do apply f [] = f apply f (p : rem_ps) = let inner_ret = AppRes (augType rem_ps) mempty - inner = mkApply f [(Nothing, freeVar p)] inner_ret + inner = mkApply f [(Nothing, mempty, freeVar p)] inner_ret in apply inner rem_ps transformSubExps :: ASTMapper LiftM diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index e26cc1cb28..935627bccd 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -121,10 +121,10 @@ entryAssert (x : xs) body = andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty eqop = Var (qualName (intrinsicVar "==")) (Info opt) mempty logAnd x' y = - mkApply andop [(Nothing, x'), (Nothing, y)] $ + mkApply andop [(Nothing, mempty, x'), (Nothing, mempty, y)] $ AppRes bool [] cmpExp (ReplacedExp x', y) = - mkApply eqop [(Nothing, x'), (Nothing, y')] $ + mkApply eqop [(Nothing, mempty, x'), (Nothing, mempty, y')] $ AppRes bool [] where y' = Var (qualName y) (Info i64) mempty @@ -415,7 +415,7 @@ transformFName loc fname t = do ( i - 1, mkApply f - [(Nothing, size_arg)] + [(Nothing, mempty, size_arg)] (AppRes (foldFunType (replicate i i64) (RetType [] t')) []) ) @@ -539,7 +539,7 @@ transformAppExp (Apply fe args _) res = <*> mapM onArg (NE.toList args) <*> transformAppRes res where - onArg (Info ext, e) = (ext,) <$> transformExp e + onArg (Info (ext, am), e) = (ext,am,) <$> transformExp e transformAppExp (Loop sparams pat e1 form body loc) res = do e1' <- transformExp e1 @@ -603,8 +603,8 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do where applyOp ret ext fname' x y = mkApply - (mkApply fname' [(unInfo d1, x)] (AppRes ret mempty)) - [(unInfo d2, y)] + (mkApply fname' [(unInfo d1, mempty, x)] (AppRes ret mempty)) + [(unInfo d2, mempty, y)] (AppRes ret ext) makeVarParam arg = do @@ -790,7 +790,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( let apply_left = mkApply op - [(xext, e1)] + [(xext, mempty, e1)] (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) []) onDim (Var d typ _) | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc @@ -799,7 +799,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( rettype' = first onDim rettype body <- scoping (S.fromList [v1, v2]) $ - mkApply apply_left [(yext, e2)] + mkApply apply_left [(yext, mempty, e2)] <$> transformAppRes (AppRes (toStruct rettype') retext) rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype' pure . wrap_left . wrap_right $ diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 5a9030aa65..ac76cf6645 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -855,7 +855,7 @@ evalAppExp env (Apply f args loc) = do f' <- eval env f foldM (apply loc env) f' args' where - evalArg' (Info ext, x) = evalArg env x ext + evalArg' (Info (ext, _), x) = evalArg env x ext evalAppExp env (Index e is loc) = do is' <- mapM (evalDimIndex env) is arr <- eval env e diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index e47d7d19f3..e85009ad4a 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -23,6 +23,7 @@ module Language.Futhark.Syntax Shape (..), shapeRank, stripDims, + AutoMap (..), TypeBase (..), TypeArg (..), SizeExp (..), @@ -230,7 +231,10 @@ sizeFromInteger x = IntLit x (Info <$> Scalar $ Prim $ Signed Int64) -- | The size of an array type is a list of its dimension sizes. If -- 'Nothing', that dimension is of a (statically) unknown size. -newtype Shape dim = Shape {shapeDims :: [dim]} +data Shape dim + = Shape {shapeDims :: [dim]} + | SVar VName + | SConcat (Shape dim) (Shape dim) deriving (Eq, Ord, Show) instance Foldable Shape where @@ -244,6 +248,9 @@ instance Functor Shape where instance Semigroup (Shape dim) where Shape l1 <> Shape l2 = Shape $ l1 ++ l2 + Shape [] <> s = s + s <> Shape [] = s + s1 <> s2 = s1 `SConcat` s2 instance Monoid (Shape dim) where mempty = Shape [] @@ -260,6 +267,19 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing +data AutoMap = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size + } + deriving (Eq, Show, Ord) + +instance Semigroup AutoMap where + (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) + +instance Monoid AutoMap where + mempty = AutoMap mempty mempty mempty + -- | The name (if any) of a function parameter. The 'Eq' and 'Ord' -- instances always compare values of this type equal. data PName = Named VName | Unnamed @@ -630,7 +650,7 @@ data AppExpBase f vn -- identical). Apply (ExpBase f vn) - (NE.NonEmpty (f (Maybe VName), ExpBase f vn)) + (NE.NonEmpty (f (Maybe VName, AutoMap), ExpBase f vn)) SrcLoc | Range (ExpBase f vn) @@ -1258,7 +1278,7 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) -- | Construct an 'Apply' node, with type information. -mkApply :: ExpBase Info vn -> [(Maybe VName, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap, ExpBase Info vn)] -> AppRes -> ExpBase Info vn mkApply f args (AppRes t ext) | Just args' <- NE.nonEmpty $ map onArg args = case f of @@ -1270,7 +1290,7 @@ mkApply f args (AppRes t ext) AppExp (Apply f args' (srcspan f $ snd $ NE.last args')) (Info (AppRes t ext)) | otherwise = f where - onArg (v, x) = (Info v, x) + onArg (v, am, x) = (Info (v, am), x) -- | Construct an 'Apply' node, without type information. mkApplyUT :: ExpBase NoInfo vn -> ExpBase NoInfo vn -> ExpBase NoInfo vn diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index d971f48def..5c1198537f 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -710,12 +710,13 @@ checkExp (AppExp (Apply f args loc) appres) = do res_als <- checkFuncall loc (fname f) f_als args_als pure (AppExp (Apply f' args' loc) appres, res_als) where + -- neUnzip3 xs = ((\(x, _, _) -> x) <$> xs, (\(_, y, _) -> y) <$> xs, (\(_, _, z) -> z) <$> xs) fname (Var v _ _) = Just v fname (AppExp (Apply e _ _) _) = fname e fname _ = Nothing - checkArg' prev d (Info p, e) = do + checkArg' prev d (Info (p, am), e) = do (e', e_als) <- checkArg prev (second (const d) (typeOf e)) e - pure ((Info p, e'), e_als) + pure ((Info (p, am), e'), e_als) checkArgs (Scalar (Arrow _ _ d _ (RetType _ rt))) (x NE.:| args') = do -- Note Futhark uses right-to-left evaluation of applications. diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index c27ec06f5f..59fc6d5ba3 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -547,7 +547,7 @@ checkExp (AppExp (Apply fe args loc) _) = do (_, rt, argext, exts) <- checkApply loc (fname, i) t arg' pure ( (i + 1, all_exts <> exts, rt), - (Info argext, arg') + (Info (argext, mempty), arg') ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e @@ -1099,7 +1099,7 @@ causalityCheck binding_body = do seqArgs known' [] = do void $ onExp known' f modify (S.fromList (appResExt res) <>) - seqArgs known' ((Info p, x) : xs) = do + seqArgs known' ((Info (p, _), x) : xs) = do new_known <- collectingNewKnown $ onExp known' x void $ seqArgs (new_known <> known') xs modify ((new_known <> S.fromList (maybeToList p)) <>) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b822636581..85ade3c1f7 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -782,7 +782,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do rt <- checkApply loc (fname, i) (toType f_t) arg' pure ( (i + 1, rt), - (Info Nothing, arg') + (Info (Nothing, mempty), arg') ) -- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do From db8db0cf02a80c8083422f26a020de14fc41d72b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 10 Feb 2024 11:34:34 +0100 Subject: [PATCH 015/296] More work on using the type information. --- .../Futhark/TypeChecker/Constraints.hs | 32 ++++++++++++--- src/Language/Futhark/TypeChecker/Terms.hs | 19 +++++---- .../Futhark/TypeChecker/Terms/Monad.hs | 32 +++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 40 ++++++++++--------- 4 files changed, 74 insertions(+), 49 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 74a2dd2fd0..bd7c12f7bb 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -6,6 +6,7 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, + Solution, solve, ) where @@ -94,12 +95,31 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt -solution :: SolverState -> M.Map TyVar Type -solution s = M.mapMaybe f $ solverTyVars s +-- | A solution maps types to the set of type variables that must be +-- substituted with this type. This slightly odd representation is +-- needed to encode when two type variables are actually the same +-- type. This matters when we start instanting the sizes of the type. +type Solution = M.Map Type [TyVar] + +solution :: SolverState -> Solution +solution s = + M.fromList $ + map adjust $ + M.toList $ + foldl addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ + M.toList $ + solverTyVars s where - f (TyVarSol t) = Just $ substTyVars (solverTyVars s) t - f (TyVarLink v) = f =<< M.lookup v (solverTyVars s) - f (TyVarUnsol _) = Nothing + mkSubst (TyVarSol t) = Just (t, []) + mkSubst _ = Nothing + addLinks m (v1, TyVarLink v2) = + case M.lookup v2 $ solverTyVars s of + Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) + _ -> case M.lookup v2 m of + Nothing -> m + Just (t, vs) -> M.insert v2 (t, v1 : vs) m + addLinks m _ = m + adjust (v, (t, vs)) = (t, v : vs) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) @@ -172,7 +192,7 @@ solveCt ct = do Nothing -> bad Just eqs -> mapM_ solveCt' eqs -solve :: Constraints -> TyVars -> Either T.Text (M.Map TyVar Type) +solve :: Constraints -> TyVars -> Either T.Text Solution solve constraints tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 59fc6d5ba3..b55dca1ad8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -27,6 +27,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Debug.Trace import Futhark.Util (mapAccumLM, nubOrd, topologicalSort) import Futhark.Util.Pretty hiding (space) import Language.Futhark @@ -522,9 +523,10 @@ checkExp (QualParens (modname, modnameloc) e loc) = do ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." -checkExp (Var qn _ loc) = do - t <- lookupVar loc qn - pure $ Var qn (Info t) loc +checkExp (Var qn (Info t) loc) = do + t' <- lookupVar loc qn + unify (mkUsage loc "inferred rank type") t t' + pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg pure $ Negate arg' loc @@ -1620,14 +1622,15 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t -addInitialConstraints :: M.Map VName (TypeBase () NoUniqueness) -> TermTypeM () +addInitialConstraints :: M.Map (TypeBase () NoUniqueness) [VName] -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where addConstraint v c = modifyConstraints $ M.insert v (0, c) - usage = mkUsage (mempty :: Loc) "trust me bro" - f (v, t) = do - (t', _) <- allDimsFreshInType usage Nonrigid "dv" t - addConstraint v $ Constraint (RetType [] t') usage + usage = mkUsage (mempty :: Loc) + f (t, vs) = do + (t', _) <- allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t + forM_ vs $ \v -> + addConstraint v $ Constraint (RetType [] t') $ usage $ prettyNameText v -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 49c5226d29..e8525e0dfd 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -347,22 +347,6 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." --- | Instantiate a type scheme with fresh type variables for its type --- parameters. Returns the names of the fresh type variables, the --- instance list, and the instantiated type. -instantiateTypeScheme :: - QualName VName -> - SrcLoc -> - [TypeParam] -> - StructType -> - TermTypeM ([VName], StructType) -instantiateTypeScheme qn loc tparams t = do - let tnames = map typeParamName tparams - (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams - let substs = M.fromList $ zip tnames tparam_substs - t' = applySubst (`M.lookup` substs) t - pure (tparam_names, t') - -- | Create a new type name and insert it (unconstrained) in the -- substitution map. instantiateTypeParam :: @@ -385,6 +369,22 @@ instantiateTypeParam qn loc tparam = do "instantiated size parameter of " <> dquotes (pretty qn) pure (v, ExpSubst $ sizeFromName (qualName v) loc) +-- | Instantiate a type scheme with fresh type variables for its type +-- parameters. Returns the names of the fresh type variables, the +-- instance list, and the instantiated type. +instantiateTypeScheme :: + QualName VName -> + SrcLoc -> + [TypeParam] -> + StructType -> + TermTypeM ([VName], StructType) +instantiateTypeScheme qn loc tparams t = do + let tnames = map typeParamName tparams + (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams + let substs = M.fromList $ zip tnames tparam_substs + t' = applySubst (`M.lookup` substs) t + pure (tparam_names, t') + lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) | baseTag q <= maxIntrinsicTag = asks termScope -- Magical intrinsic module. diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 85ade3c1f7..b8c17c57c9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -366,23 +366,24 @@ lookupMod qn@(QualName _ name) = do lookupVar :: SrcLoc -> QualName VName -> TermM StructType lookupVar loc qn@(QualName qs name) = do scope <- lookupQualNameEnv qn - case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams t) -> do - if null tparams && null qs - then pure t - else do - (tnames, t') <- instTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newType loc "t" - pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeOverloaded loc "t" ts - let (pts', rt') = instOverloaded (argtype :: StructType) pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + asStructType loc + =<< case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams t) -> do + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newType loc "t" + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeOverloaded loc "t" ts + let (pts', rt') = instOverloaded (argtype :: StructType) pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -1018,7 +1019,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text (M.Map TyVar Type), + ( Either T.Text Solution, [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1042,6 +1043,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, "## solution:", - either T.unpack (unlines . map (prettyString . first prettyNameString) . M.toList) solution + let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') From 2bd5fe96e372b70d66cd2ace5bdc65cb4a4d27fd Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 10:14:55 -0800 Subject: [PATCH 016/296] Start adding AUTOMAP machinery. --- src/Futhark/Internalise/Defunctionalise.hs | 4 +- src/Language/Futhark/Syntax.hs | 24 +++----- .../Futhark/TypeChecker/Constraints.hs | 36 ++++++++---- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 58 ++++++++++++------- 5 files changed, 76 insertions(+), 48 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index d6b03a368b..6cd66fbce6 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -910,7 +910,7 @@ liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> - (((Maybe VName, AutoMap), Exp), [ParamType]) -> + (((Maybe VName, AutoMap Size), Exp), [ParamType]) -> DefM (Exp, StaticVal) defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg @@ -991,7 +991,7 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap Size), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index b4b56aecf6..d1e98a165f 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -231,10 +231,7 @@ sizeFromInteger x = IntLit x (Info <$> Scalar $ Prim $ Signed Int64) -- | The size of an array type is a list of its dimension sizes. If -- 'Nothing', that dimension is of a (statically) unknown size. -data Shape dim - = Shape {shapeDims :: [dim]} - | SVar VName - | SConcat (Shape dim) (Shape dim) +newtype Shape dim = Shape {shapeDims :: [dim]} deriving (Eq, Ord, Show) instance Foldable Shape where @@ -248,9 +245,6 @@ instance Functor Shape where instance Semigroup (Shape dim) where Shape l1 <> Shape l2 = Shape $ l1 ++ l2 - Shape [] <> s = s - s <> Shape [] = s - s1 <> s2 = s1 `SConcat` s2 instance Monoid (Shape dim) where mempty = Shape [] @@ -267,17 +261,17 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap = AutoMap - { autoRep :: Shape Size, - autoMap :: Shape Size, - autoFrame :: Shape Size +data AutoMap u = AutoMap + { autoRep :: Shape u, + autoMap :: Shape u, + autoFrame :: Shape u } deriving (Eq, Show, Ord) -instance Semigroup AutoMap where +instance Semigroup (AutoMap u) where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) -instance Monoid AutoMap where +instance Monoid (AutoMap u) where mempty = AutoMap mempty mempty mempty -- | The name (if any) of a function parameter. The 'Eq' and 'Ord' @@ -716,7 +710,7 @@ data AppExpBase f vn -- identical). Apply (ExpBase f vn) - (NE.NonEmpty (f (Maybe VName, AutoMap), ExpBase f vn)) + (NE.NonEmpty (f (Maybe VName, AutoMap Size), ExpBase f vn)) SrcLoc | Range (ExpBase f vn) @@ -1344,7 +1338,7 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) -- | Construct an 'Apply' node, with type information. -mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap Size, ExpBase Info vn)] -> AppRes -> ExpBase Info vn mkApply f args (AppRes t ext) | Just args' <- NE.nonEmpty $ map onArg args = case f of diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index bd7c12f7bb..b6b5507283 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,5 +1,7 @@ module Language.Futhark.TypeChecker.Constraints - ( Type, + ( SVar, + SComp (..), + Type, toType, Ct (..), Constraints, @@ -16,29 +18,43 @@ import Control.Monad.State import Data.Bifunctor import Data.Map qualified as M import Data.Text qualified as T +import Futhark.IR.Pretty import Futhark.Util.Pretty import Language.Futhark --- | A shape component is currently just unit. The rank of an array is --- then just the number of shape components it contains in its shape --- list. When we add AUTOMAP, these components will also allow shape --- variables. The list of components should then be understood as --- concatenation of shapes (meaning you can't just take the length to --- determine the rank of the array). -type SComp = () +type SVar = VName + +-- | A shape component. `SDim` is a single dimension of unspecified +-- size, `SVar` is a shape variable. A list of shape components should +-- then be understood as concatenation of shapes (meaning you can't +-- just take the length to determine the rank of the array). +data SComp + = SDim + | SVar SVar + deriving (Eq, Ord, Show) + +instance Pretty SComp where + pretty (SDim) = "[]" + pretty (SVar x) = pretty x + +instance Pretty (Shape SComp) where + pretty = mconcat . map (brackets . pretty) . shapeDims -- | The type representation used by the constraint solver. Agnostic -- to sizes. type Type = TypeBase SComp NoUniqueness toType :: TypeBase d u -> Type -toType = bimap (const ()) (const NoUniqueness) +toType = bimap (const SDim) (const NoUniqueness) -data Ct = CtEq Type Type +data Ct + = CtEq Type Type + | CtAM SVar SVar deriving (Show) instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 + pretty (CtAM r m) = pretty r <+> "=" <+> "•" <+> "∨" <+> pretty m <+> "=" <+> "•" type Constraints = [Ct] diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index b55dca1ad8..2a7bbbd6da 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1656,7 +1656,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right tysubsts -> runTermTypeM checkExp $ do - addInitialConstraints tysubsts + addInitialConstraints $ M.mapKeys (first $ const ()) tysubsts (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b8c17c57c9..a037e0f2ec 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -235,6 +235,12 @@ newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBa newTypeOverloaded loc name pts = tyVarType <$> newTyVarWith loc name (TyVarPrim pts) +newSVar :: (Located loc) => loc -> Name -> TermM SVar +newSVar _loc desc = do + i <- incCounter + v <- newID $ mkTypeVarName desc i + pure v + asStructType :: (Monoid u) => SrcLoc -> TypeBase d u -> TermM (TypeBase Size u) asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] @@ -259,6 +265,9 @@ ctEq t1 t2 = t1' = toType t1 t2' = toType t2 +ctAM :: SVar -> SVar -> TermM () +ctAM r m = addCt $ CtAM r m + localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -328,7 +337,7 @@ instance MonadTypeChecker TermM where arrayOfRank :: Int -> Type -> Type arrayOfRank 0 t = t -arrayOfRank n t = arrayOf (Shape $ replicate n ()) t +arrayOfRank n t = arrayOf (Shape $ replicate n SDim) t require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why pts e = do @@ -346,13 +355,14 @@ instTypeScheme :: StructType -> TermM ([VName], StructType) instTypeScheme _qn loc tparams t = do - (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> do - case tparam of - TypeParamType _ v _ -> do - v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) - TypeParamDim {} -> - pure Nothing + (names, substs) <- fmap (unzip . catMaybes) $ + forM tparams $ \tparam -> do + case tparam of + TypeParamType _ v _ -> do + v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) + TypeParamDim {} -> + pure Nothing let t' = applySubst (`lookup` substs) t pure (names, t') @@ -575,16 +585,24 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM Type -checkApply _ _ (Scalar (Arrow _ _ _ a (RetType _ b))) arg = do - ctEq a $ expType arg - pure $ toType b +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap SComp) checkApply loc _ ftype arg = do - a <- newType loc "arg" - b <- newTyVar loc "res" - ctEq ftype $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) - ctEq a (expType arg) - pure $ tyVarType b + (a, b) <- split ftype + r <- newSVar loc "R" + m <- newSVar loc "M" + let s_r = Shape $ pure $ SVar r + s_m = Shape $ pure $ SVar m + ctAM r m + ctEq (arrayOf s_r $ toType $ typeOf arg) (arrayOf s_m a) + pure (arrayOf s_m b, AutoMap {autoRep = s_r, autoMap = s_m, autoFrame = mempty}) + where + split (Scalar (Arrow _ _ _ a (RetType _ b))) = + pure (a, toType b) + split ftype' = do + a <- newType loc "arg" + b <- newTyVar loc "res" + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) + pure (a, tyVarType b) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex @@ -780,7 +798,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do onArg (i, f_t) (_, arg) = do arg' <- checkExp arg - rt <- checkApply loc (fname, i) (toType f_t) arg' + (rt, am) <- checkApply loc (fname, i) (toType f_t) arg' pure ( (i + 1, rt), (Info (Nothing, mempty), arg') @@ -791,8 +809,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e1' <- checkExp e1 e2' <- checkExp e2 - rt1 <- checkApply loc (Just op, 0) (toType ftype) e1' - rt2 <- checkApply loc (Just op, 1) rt1 e2' + (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) e1' + (rt2, am2) <- checkApply loc (Just op, 1) rt1 e2' rt2' <- asStructType loc rt2 pure $ From 55cb35465d4f8504af712db62e4c4e7469db22ec Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 11:00:06 -0800 Subject: [PATCH 017/296] Use `Shape Size` for AUTOMAP. --- src/Futhark/Internalise/Defunctionalise.hs | 4 ++-- src/Language/Futhark/Syntax.hs | 16 ++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 15 ++++++++++----- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 6cd66fbce6..d6b03a368b 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -910,7 +910,7 @@ liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> - (((Maybe VName, AutoMap Size), Exp), [ParamType]) -> + (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg @@ -991,7 +991,7 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap Size), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index d1e98a165f..f0f2a586df 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -261,17 +261,17 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap u = AutoMap - { autoRep :: Shape u, - autoMap :: Shape u, - autoFrame :: Shape u +data AutoMap = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size } deriving (Eq, Show, Ord) -instance Semigroup (AutoMap u) where +instance Semigroup AutoMap where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) -instance Monoid (AutoMap u) where +instance Monoid AutoMap where mempty = AutoMap mempty mempty mempty -- | The name (if any) of a function parameter. The 'Eq' and 'Ord' @@ -710,7 +710,7 @@ data AppExpBase f vn -- identical). Apply (ExpBase f vn) - (NE.NonEmpty (f (Maybe VName, AutoMap Size), ExpBase f vn)) + (NE.NonEmpty (f (Maybe VName, AutoMap), ExpBase f vn)) SrcLoc | Range (ExpBase f vn) @@ -1338,7 +1338,7 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) -- | Construct an 'Apply' node, with type information. -mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap Size, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply :: ExpBase Info vn -> [(Maybe VName, AutoMap, ExpBase Info vn)] -> AppRes -> ExpBase Info vn mkApply f args (AppRes t ext) | Just args' <- NE.nonEmpty $ map onArg args = case f of diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index a037e0f2ec..e6e9b26776 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -585,17 +585,22 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap SComp) +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap) checkApply loc _ ftype arg = do (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" - let s_r = Shape $ pure $ SVar r - s_m = Shape $ pure $ SVar m + let unit_info = Info $ Scalar $ Prim Bool + r_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] r) unit_info mempty ctAM r m - ctEq (arrayOf s_r $ toType $ typeOf arg) (arrayOf s_m a) - pure (arrayOf s_m b, AutoMap {autoRep = s_r, autoMap = s_m, autoFrame = mempty}) + ctEq (arrayOf (toShape $ SVar r) $ toType $ typeOf arg) (arrayOf (toShape $ SVar m) a) + pure + ( arrayOf (toShape $ SVar m) b, + AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = mempty} + ) where + toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, toType b) split ftype' = do From a67a1158f7870565f3572c324aecb5b8f3d5502c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 13:20:39 -0800 Subject: [PATCH 018/296] Add rank analysis stuff. --- futhark.cabal | 1 + src/Futhark/Solve/LP.hs | 40 ++++---- src/Language/Futhark/TypeChecker/Rank.hs | 108 +++++++++++++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 4 files changed, 130 insertions(+), 21 deletions(-) create mode 100644 src/Language/Futhark/TypeChecker/Rank.hs diff --git a/futhark.cabal b/futhark.cabal index 1c0531a83b..80006d4cc9 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -418,6 +418,7 @@ library Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules Language.Futhark.TypeChecker.Monad + Language.Futhark.TypeChecker.Rank Language.Futhark.TypeChecker.Terms Language.Futhark.TypeChecker.Terms2 Language.Futhark.TypeChecker.Terms.Loop diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 11e943a1b1..3b46af1965 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -20,9 +20,9 @@ module Futhark.Solve.LP LinearProg (..), OptType (..), Constraint (..), - (==), - (<=), - (>=), + (~==~), + (~<=~), + (~>=~), rowEchelonLPE, ) where @@ -143,47 +143,47 @@ instance (Show v, Show a) => Show (LinearProg v a) where bigM :: (Num a) => a bigM = 10 ^ 3 -oneIsZero :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] -oneIsZero b1 b2 x1 x2 = +oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] +oneIsZero (b1, x1) (b2, x2) = mkC b1 x1 <> mkC b2 x2 - <> [(var b1 ~+~ var b2) <= constant 1] + <> [(var b1 ~+~ var b2) ~<=~ constant 1] where mkC b x = - [ var x <= bigM ~*~ var b + [ var x ~<=~ bigM ~*~ var b ] or :: (Eq a, Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] or b1 b2 c1 c2 = mkC b1 c1 <> mkC b2 c2 - <> [var b1 ~+~ var b2 <= constant 1] + <> [var b1 ~+~ var b2 ~<=~ constant 1] where mkC b (Constraint Equal l r) = - [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b), - l >= r ~-~ bigM ~*~ (constant 1 ~-~ var b) + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b), + l ~>=~ r ~-~ bigM ~*~ (constant 1 ~-~ var b) ] mkC b (Constraint LessEq l r) = - [ l <= r ~+~ bigM ~*~ (constant 1 ~-~ var b) + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b) ] bin :: (Num a, Ord v) => v -> Constraint v a bin v = Constraint LessEq (var v) (constant 1) -(==) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a -l == r = Constraint Equal l r +(~==~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l ~==~ r = Constraint Equal l r -infix 4 == +infix 4 ~==~ -(<=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a -l <= r = Constraint LessEq l r +(~<=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l ~<=~ r = Constraint LessEq l r -infix 4 <= +infix 4 ~<=~ -(>=) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a -l >= r = Constraint LessEq (neg l) (neg r) +(~>=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +l ~>=~ r = Constraint LessEq (neg l) (neg r) -infix 4 >= +infix 4 ~>=~ normalize :: (Eq a, Num a) => LSum v a -> LSum v a normalize = LSum . Map.filter (/= 0) . lsum diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs new file mode 100644 index 0000000000..e9bf7de859 --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -0,0 +1,108 @@ +module Language.Futhark.TypeChecker.Rank (rankAnalysis) where + +import Control.Monad.State +import Data.Map (Map) +import Data.Map qualified as M +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.BranchAndBound +import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) +import Futhark.Solve.LP qualified as LP +import Language.Futhark hiding (ScalarType) +import Language.Futhark.TypeChecker.Constraints + +type LSum = LP.LSum VName Double + +type Constraint = LP.Constraint VName Double + +type LinearProg = LP.LinearProg VName Double + +type ScalarType = ScalarTypeBase SComp NoUniqueness + +class Rank a where + rank :: a -> LSum + +instance Rank SComp where + rank SDim = constant 1 + rank (SVar v) = var v + +instance Rank (Shape SComp) where + rank = foldr (\d r -> rank d ~+~ r) (constant 0) . shapeDims + +instance Rank ScalarType where + rank Prim {} = constant 0 + rank (TypeVar _ (QualName [] v) []) = var v + rank (Arrow {}) = constant 0 + rank t = error $ prettyString t + +instance Rank Type where + rank (Scalar t) = rank t + rank (Array _ shape t) = rank shape ~+~ rank t + +data RankState = RankState + { rankBinVars :: Map VName VName, + rankCounter :: !Int, + rankConstraints :: [Constraint] + } + +newtype RankM a = RankM {runRankM :: State RankState a} + deriving (Functor, Applicative, Monad, MonadState RankState) + +incCounter :: RankM Int +incCounter = do + s <- get + put s {rankCounter = rankCounter s + 1} + pure $ rankCounter s + +binVar :: VName -> RankM (VName) +binVar sv = do + mbv <- (M.!? sv) <$> gets rankBinVars + case mbv of + Nothing -> do + bv <- VName ("b_" <> baseName sv) <$> incCounter + modify $ \s -> + s + { rankBinVars = M.insert sv bv $ rankBinVars s, + rankConstraints = rankConstraints s ++ [bin bv] + } + pure bv + Just bv -> pure bv + +addConstraints :: [Constraint] -> RankM () +addConstraints cs = + modify $ \s -> s {rankConstraints = rankConstraints s ++ cs} + +addConstraint :: Constraint -> RankM () +addConstraint = addConstraints . pure + +addCt :: Ct -> RankM () +addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 +addCt (CtAM r m) = do + b_r <- binVar r + b_m <- binVar m + addConstraints $ oneIsZero (b_r, r) (b_m, m) + +mkLinearProg :: Int -> [Ct] -> LinearProg +mkLinearProg counter cs = + LP.LinearProg + { optType = Minimize, + objective = + let shape_vars = M.keys $ rankBinVars finalState + in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, + constraints = rankConstraints finalState + } + where + initState = + RankState + { rankBinVars = mempty, + rankCounter = counter, + rankConstraints = mempty + } + finalState = flip execState initState $ runRankM $ mapM_ addCt cs + +rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) +rankAnalysis counter cs = do + (_size, ranks) <- branchAndBound lp + pure $ (ranks V.!) <$> inv_var_map + where + (lp, var_map) = linearProgToLP $ mkLinearProg counter cs + inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e6e9b26776..a9f6444980 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -806,7 +806,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do (rt, am) <- checkApply loc (fname, i) (toType f_t) arg' pure ( (i + 1, rt), - (Info (Nothing, mempty), arg') + (Info (Nothing, am), arg') ) -- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do From 9318e211843774be85594f22097ca8b3ac75b0d5 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 22:32:41 -0800 Subject: [PATCH 019/296] Starting to integrate the rank solver. --- src/Futhark/Solve/LP.hs | 55 ++++++++++++---------- src/Language/Futhark/TypeChecker/Rank.hs | 32 ++++++++++++- src/Language/Futhark/TypeChecker/Terms2.hs | 33 +++++++++++++ 3 files changed, 93 insertions(+), 27 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 3b46af1965..d5c0ee6c5e 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -36,7 +36,9 @@ import Data.Vector.Unboxed qualified as V import Debug.Trace import Futhark.Solve.Matrix (Matrix (..)) import Futhark.Solve.Matrix qualified as M -import Prelude hiding (or, (<=), (==), (>=)) +import Futhark.Util.Pretty +import Language.Futhark.Pretty +import Prelude hiding (or) import Prelude qualified -- | A linear program. 'LP c a d' represents the program @@ -86,17 +88,16 @@ convert (LP c a d) = LPE c' a' d -- | Linear sum of variables. newtype LSum v a = LSum {lsum :: (Map (Maybe v) a)} - deriving (Eq) + deriving (Show, Eq) -instance (Show v, Show a) => Show (LSum v a) where - show (LSum m) = - L.intercalate - " + " +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where + pretty (LSum m) = + concatWith (surround " + ") $ map ( \(k, a) -> case k of - Nothing -> show a - Just k' -> show a <> "*" <> show k' + Nothing -> pretty a + Just k' -> (if a == 1 then mempty else pretty a <> "·") <> prettyName k' ) $ Map.toList m @@ -105,40 +106,44 @@ instance Functor (LSum v) where -- | Type of constraint data CType = Equal | LessEq - deriving (Eq) + deriving (Show, Eq) -instance Show CType where - show (Equal) = "=" - show (LessEq) = "<=" +instance Pretty CType where + pretty Equal = "=" + pretty LessEq = "<=" -- | A constraint for a linear program. data Constraint v a = Constraint CType (LSum v a) (LSum v a) - deriving (Eq) + deriving (Show, Eq) -instance (Show a, Show v) => Show (Constraint v a) where - show (Constraint t l r) = - show l <> " " <> show t <> " " <> show r +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (Constraint v a) where + pretty (Constraint t l r) = + pretty l <+> pretty t <+> pretty r data OptType = Maximize | Minimize deriving (Show, Eq) +instance Pretty OptType where + pretty Maximize = "maximize" + pretty Minimize = "minimize" + -- | A linear program. data LinearProg v a = LinearProg { optType :: OptType, objective :: LSum v a, constraints :: [Constraint v a] } - deriving (Eq) - -instance (Show v, Show a) => Show (LinearProg v a) where - show (LinearProg opt obj cs) = - unlines $ - [ show opt, - show obj, - "subject to:" + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where + pretty (LinearProg opt obj cs) = + vcat $ + [ pretty opt, + indent 2 $ pretty obj, + "subject to", + indent 2 $ vcat $ map pretty cs ] - ++ map show cs bigM :: (Num a) => a bigM = 10 ^ 3 diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index e9bf7de859..5ef9f72594 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -3,10 +3,13 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where import Control.Monad.State import Data.Map (Map) import Data.Map qualified as M +import Data.Maybe import Data.Vector.Unboxed qualified as V +import Debug.Trace import Futhark.Solve.BranchAndBound import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP +import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints @@ -38,6 +41,20 @@ instance Rank Type where rank (Scalar t) = rank t rank (Array _ shape t) = rank shape ~+~ rank t +class Distribute a where + distribute :: a -> a + +instance Distribute Type where + distribute = distributeOne + where + distributeOne (Array _ s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s $ tr) + distributeOne t = t + +instance Distribute Ct where + distribute (CtEq t1 t2) = distribute t1 `CtEq` distribute t2 + distribute c = c + data RankState = RankState { rankBinVars :: Map VName VName, rankCounter :: !Int, @@ -101,8 +118,19 @@ mkLinearProg counter cs = rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) rankAnalysis counter cs = do + traceM $ unlines $ concat $ map (\c -> [prettyString c, show c]) cs' + traceM $ prettyString prog (_size, ranks) <- branchAndBound lp - pure $ (ranks V.!) <$> inv_var_map + pure $ (fromJust . (ranks V.!?)) <$> inv_var_map where - (lp, var_map) = linearProgToLP $ mkLinearProg counter cs + splitFuncs + ( CtEq + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) + ) = + splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq (toType t1r) (toType t2r)) + splitFuncs c = [c] + cs' = foldMap (splitFuncs . distribute) cs + prog = mkLinearProg counter cs' + (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index a9f6444980..3f233a05c6 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -59,6 +59,7 @@ import Language.Futhark import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM +import Language.Futhark.TypeChecker.Rank import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify (Level) import Prelude hiding (mod) @@ -1070,3 +1071,35 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') + +-- checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do +-- bindParams tparams params $ \params' -> do +-- body' <- checkExp body +-- +-- retdecl' <- checkRetDecl body' retdecl +-- +-- cts <- gets termConstraints +-- +-- counter <- gets termCounter +-- +-- traceM $ unlines $ map prettyString cts +-- +-- case rankAnalysis counter cts of +-- Nothing -> error "" +-- Just rank_map -> do +-- tyvars <- gets termTyVars +-- +-- let solution = solve cts tyvars +-- +-- traceM $ +-- unlines +-- [ "# function " <> prettyNameString fname, +-- "## constraints:", +-- unlines $ map prettyString cts, +-- "## tyvars:", +-- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, +-- "## solution:", +-- let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t +-- in either T.unpack (unlines . map p . M.toList) solution +-- ] +-- pure (solution, params', retdecl', body') From e109d19e9cf03b40fb19dbe5f482b47040b1aac0 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 10 Feb 2024 22:41:27 -0800 Subject: [PATCH 020/296] Add LP/ILP unit tests. --- futhark.cabal | 3 + .../Futhark/Solve/BranchAndBoundTests.hs | 120 +++++++++++ unittests/Futhark/Solve/SimplexTests.hs | 189 ++++++++++++++++++ unittests/futhark_tests.hs | 6 +- 4 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 unittests/Futhark/Solve/BranchAndBoundTests.hs create mode 100644 unittests/Futhark/Solve/SimplexTests.hs diff --git a/futhark.cabal b/futhark.cabal index 80006d4cc9..0b66baf4cc 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -523,6 +523,8 @@ test-suite unit Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests Futhark.Pkg.SolveTests Futhark.ProfileTests + Futhark.Solve.BranchAndBoundTests + Futhark.Solve.SimplexTests Language.Futhark.CoreTests Language.Futhark.PrimitiveTests Language.Futhark.SyntaxTests @@ -540,3 +542,4 @@ test-suite unit , tasty-hunit , tasty-quickcheck , text + , vector >=0.12 diff --git a/unittests/Futhark/Solve/BranchAndBoundTests.hs b/unittests/Futhark/Solve/BranchAndBoundTests.hs new file mode 100644 index 0000000000..10867a1bee --- /dev/null +++ b/unittests/Futhark/Solve/BranchAndBoundTests.hs @@ -0,0 +1,120 @@ +module Futhark.Solve.BranchAndBoundTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.BranchAndBound +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) +import Prelude qualified + +tests :: TestTree +tests = + testGroup + "BranchAndBoundTests" + [ -- testCase "1" $ + -- let lpe = + -- LPE + -- { pc = V.fromList [1, 1, 0, 0, 0], + -- pA = + -- M.fromLists + -- [ [-1, 1, 1, 0, 0], + -- [1, 0, 0, 1, 0], + -- [0, 1, 0, 0, 1] + -- ], + -- pd = V.fromList [1, 3, 2] + -- } + -- in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in branchAndBound lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in branchAndBound lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ branchAndBound lp) $ + case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (11.8 :: Double), + and $ zipWith (==) (V.toList sol) [1, 3] + ], + testCase "5" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> oneIsZero ("b1", "x1") ("b2", "x2") + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in assertBool + (unlines [show $ branchAndBound lp]) + $ case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (10 :: Double) + ] + -- testCase "6" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ] + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/Futhark/Solve/SimplexTests.hs b/unittests/Futhark/Solve/SimplexTests.hs new file mode 100644 index 0000000000..80eee3237e --- /dev/null +++ b/unittests/Futhark/Solve/SimplexTests.hs @@ -0,0 +1,189 @@ +module Futhark.Solve.SimplexTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Debug.Trace +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Futhark.Solve.Simplex +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) +import Prelude qualified + +tests :: TestTree +tests = + testGroup + "SimplexTests" + [ testCase "1" $ + let lpe = + LPE + { pc = V.fromList [1, 1, 0, 0, 0], + pA = + M.fromLists + [ [-1, 1, 1, 0, 0], + [1, 0, 0, 1, 0], + [0, 1, 0, 0, 1] + ], + pd = V.fromList [1, 3, 2] + } + in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in simplexLP lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in simplexLP lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (14.08 :: Double), + and $ zipWith approxEq (V.toList sol) [1.3, 3.3] + ], + testCase "5" $ + let lp = + LP + { lpc = V.fromList [0], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [0, 0] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double), + and $ zipWith approxEq (V.toList sol) [0] + ], + testCase "6" $ + let lp = + LP + { lpc = V.fromList [1], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [5, 5] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (5 :: Double), + and $ zipWith approxEq (V.toList sol) [5] + ], + testCase "7" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1", + constraints = + [ var "x1" ~<=~ 10 ~*~ var "b1", + var "b1" ~+~ var "b2" ~<=~ constant 1 + ] + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in trace + (unlines [show prog, show lp, show idxmap, show lpe]) + ( assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (10 :: Double), + and $ zipWith (==) (V.toList sol) [1, 0, 10] + ] + ), + testCase "8" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> oneIsZero ("b1", "x1") ("b2", "x2") + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in trace + (unlines [show prog, show lp, show idxmap, show lpe]) + ( assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (15 :: Double) + ] + ), + testCase "9" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in trace + (unlines [show prog, show lp, show idxmap, show lpe]) + ( assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (15 :: Double) + ] + ) + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/futhark_tests.hs b/unittests/futhark_tests.hs index 32e22272cf..79986794b7 100644 --- a/unittests/futhark_tests.hs +++ b/unittests/futhark_tests.hs @@ -10,6 +10,8 @@ import Futhark.IR.Syntax.CoreTests qualified import Futhark.Internalise.TypesValuesTests qualified import Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests qualified import Futhark.Pkg.SolveTests qualified +import Futhark.Solve.BranchAndBoundTests qualified +import Futhark.Solve.SimplexTests qualified import Language.Futhark.PrimitiveTests qualified import Language.Futhark.SyntaxTests qualified import Language.Futhark.TypeCheckerTests qualified @@ -31,7 +33,9 @@ allTests = Language.Futhark.PrimitiveTests.tests, Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests.tests, Futhark.Analysis.AlgSimplifyTests.tests, - Language.Futhark.TypeCheckerTests.tests + Language.Futhark.TypeCheckerTests.tests, + Futhark.Solve.SimplexTests.tests, + Futhark.Solve.BranchAndBoundTests.tests ] main :: IO () From fd2cc94f34f4998263bd72dd93e95da5e39c9a9d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 11:28:02 -0800 Subject: [PATCH 021/296] Add conversion to PuLP for easier debugging. --- src/Futhark/Solve/LP.hs | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index d5c0ee6c5e..af9265f458 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -24,9 +24,11 @@ module Futhark.Solve.LP (~<=~), (~>=~), rowEchelonLPE, + linearProgToPulp, ) where +import Data.Char (isAscii) import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as Map @@ -97,7 +99,7 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where ( \(k, a) -> case k of Nothing -> pretty a - Just k' -> (if a == 1 then mempty else pretty a <> "·") <> prettyName k' + Just k' -> (if a == 1 then mempty else pretty a <> "*") <> prettyName k' ) $ Map.toList m @@ -109,7 +111,7 @@ data CType = Equal | LessEq deriving (Show, Eq) instance Pretty CType where - pretty Equal = "=" + pretty Equal = "==" pretty LessEq = "<=" -- | A constraint for a linear program. @@ -145,6 +147,38 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where indent 2 $ vcat $ map pretty cs ] +-- For debugging +linearProgToPulp :: (Unbox a, IsName v, Ord v, Pretty a, Eq a, Num a) => LinearProg v a -> String +linearProgToPulp prog = + map rm_subscript $ + unlines + [ "from pulp import *", + "prob = LpProblem('', " <> lptype <> ")", + unlines vars, + unlines $ map (("prob += " <>) . prettyString) $ constraints prog, + "status = prob.solve()", + "print(f'status: {status}')", + unlines res + ] + where + lptype = + case optType prog of + Maximize -> "LpMaximize" + Minimize -> "LpMinimize" + prog_vars = Map.elems $ snd $ linearProgToLP prog + vars = + map + ( \v -> + show (prettyName v) + <> " = " + <> "LpVariable(" + <> show (show (prettyName v)) + <> ", lowBound = 0, cat = 'Integer')" + ) + prog_vars + res = map (\v -> "print(f'" <> show (prettyName v) <> ": {value(" <> show (prettyName v) <> ")}')") prog_vars + rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" + bigM :: (Num a) => a bigM = 10 ^ 3 From ffcf337129cf71292d0a0889bcb8ad434c93269f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 21:30:48 +0100 Subject: [PATCH 022/296] Dummy handler for CtAM. --- src/Language/Futhark/TypeChecker/Constraints.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b6b5507283..28567a9a5b 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -171,9 +171,10 @@ unify t1 t2 unify _ _ = Nothing solveCt :: Ct -> SolveM () -solveCt ct = do - let CtEq t1 t2 = ct - solveCt' (t1, t2) +solveCt ct = + case ct of + CtEq t1 t2 -> solveCt' (t1, t2) + CtAM _ _ -> pure () -- Good vibes only. where bad = throwError $ "Unsolvable: " <> prettyText ct solveCt' (t1, t2) = do From 93b11e8acb3fc7fe8b14029f8f1361025b609ecf Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 21:30:55 +0100 Subject: [PATCH 023/296] Style fixes. --- src/Language/Futhark/TypeChecker/Constraints.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 28567a9a5b..7f43aaf323 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -34,7 +34,7 @@ data SComp deriving (Eq, Ord, Show) instance Pretty SComp where - pretty (SDim) = "[]" + pretty SDim = "[]" pretty (SVar x) = pretty x instance Pretty (Shape SComp) where @@ -158,7 +158,7 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = Just $ M.elems $ M.intersectionWith (,) fs1 fs2 unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) - | M.keys cs1 == M.keys cs2 = do + | M.keys cs1 == M.keys cs2 = fmap concat . forM (M.elems $ M.intersectionWith (,) cs1 cs2) $ \(ts1, ts2) -> do From de224f7b9e9155fee2b93b34dbec0f15ac04c222 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 21:33:22 +0100 Subject: [PATCH 024/296] Better prettyprinting. --- src/Language/Futhark/TypeChecker/Constraints.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 7f43aaf323..8ef97e2287 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -35,10 +35,10 @@ data SComp instance Pretty SComp where pretty SDim = "[]" - pretty (SVar x) = pretty x + pretty (SVar x) = brackets $ pretty x instance Pretty (Shape SComp) where - pretty = mconcat . map (brackets . pretty) . shapeDims + pretty = mconcat . map pretty . shapeDims -- | The type representation used by the constraint solver. Agnostic -- to sizes. From cbc4356ce144fd6e4cabc5ffc3c3698b6d41c120 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 22:21:58 +0100 Subject: [PATCH 025/296] Better prettyprinting of arrays of functions. --- src/Language/Futhark/Pretty.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index d81553f5b8..81ca4a152f 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -158,7 +158,7 @@ instance (Pretty (Shape dim), Pretty u) => Pretty (ScalarTypeBase dim u) where prettyType :: (Pretty (Shape dim), Pretty u) => Int -> TypeBase dim u -> Doc a prettyType _ (Array u shape at) = - pretty u <> pretty shape <> align (prettyScalarType 1 at) + pretty u <> pretty shape <> align (prettyScalarType 2 at) prettyType p (Scalar t) = prettyScalarType p t From 2a65a7660e4baa500ee04489cc71d66e78894ba3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 11 Feb 2024 22:59:36 +0100 Subject: [PATCH 026/296] Make toType require Size sizes. This makes it harder to accidentally throw away shape variables. --- .../Futhark/TypeChecker/Constraints.hs | 11 +- src/Language/Futhark/TypeChecker/Rank.hs | 5 +- src/Language/Futhark/TypeChecker/Terms2.hs | 148 +++++++++--------- 3 files changed, 89 insertions(+), 75 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 8ef97e2287..a81812def5 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -44,8 +44,10 @@ instance Pretty (Shape SComp) where -- to sizes. type Type = TypeBase SComp NoUniqueness -toType :: TypeBase d u -> Type -toType = bimap (const SDim) (const NoUniqueness) +-- | Careful when using this on something that already has an SComp +-- size: it will throw away information by converting them to SDim. +toType :: TypeBase Size u -> TypeBase SComp u +toType = first (const SDim) data Ct = CtEq Type Type @@ -153,7 +155,10 @@ unify :: Type -> Type -> Maybe [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Just [] unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Just [(t1a, t2a), (toType t1r, toType t2r)] + Just [(t1a, t2a), (t1r', t2r')] + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = Just $ M.elems $ M.intersectionWith (,) fs1 fs2 diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 5ef9f72594..ee9eaafdcc 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -128,7 +128,10 @@ rankAnalysis counter cs = do (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) ) = - splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq (toType t1r) (toType t2r)) + splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq t1r' t2r') + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] cs' = foldMap (splitFuncs . distribute) cs prog = mkLinearProg counter cs' diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 3f233a05c6..c79c67f5e2 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -224,25 +224,26 @@ newTyVar loc desc = newTyVarWith loc desc TyVarFree newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) newType loc desc = tyVarType <$> newTyVar loc desc -newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> TypeBase dim u -> TermM (TypeBase dim u) +newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> Type -> TermM (TypeBase d u) newTypeWithField loc desc k t = - tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k $ toType t) + tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) -newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [Type] -> TermM (TypeBase dim u) +newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) newTypeWithConstr loc desc k ts = - tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts) + tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') + where + ts' = map (`setUniqueness` NoUniqueness) ts -newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase dim u) +newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d u) newTypeOverloaded loc name pts = tyVarType <$> newTyVarWith loc name (TyVarPrim pts) newSVar :: (Located loc) => loc -> Name -> TermM SVar newSVar _loc desc = do i <- incCounter - v <- newID $ mkTypeVarName desc i - pure v + newID $ mkTypeVarName desc i -asStructType :: (Monoid u) => SrcLoc -> TypeBase d u -> TermM (TypeBase Size u) +asStructType :: (Monoid u) => SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do @@ -251,20 +252,20 @@ asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' asStructType loc t = do t' <- newType loc "artificial" - ctEq t' t + ctEq (toType t' `setUniqueness` NoUniqueness) (t `setUniqueness` NoUniqueness) pure t' addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} -ctEq :: TypeBase d1 u1 -> TypeBase d2 u2 -> TermM () +ctEq :: TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () ctEq t1 t2 = -- As a minor optimisation, do not add constraint if the types are -- equal. unless (t1' == t2') $ addCt $ CtEq t1' t2' where - t1' = toType t1 - t2' = toType t2 + t1' = t1 `setUniqueness` NoUniqueness + t2' = t2 `setUniqueness` NoUniqueness ctAM :: SVar -> SVar -> TermM () ctAM r m = addCt $ CtAM r m @@ -377,24 +378,23 @@ lookupMod qn@(QualName _ name) = do lookupVar :: SrcLoc -> QualName VName -> TermM StructType lookupVar loc qn@(QualName qs name) = do scope <- lookupQualNameEnv qn - asStructType loc - =<< case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams t) -> do - if null tparams && null qs - then pure t - else do - (tnames, t') <- instTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newType loc "t" - pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeOverloaded loc "t" ts - let (pts', rt') = instOverloaded (argtype :: StructType) pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams t) -> do + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newType loc "t" + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeOverloaded loc "t" ts + let (pts', rt') = instOverloaded argtype pts rt + pure $ foldFunType pts' $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -452,9 +452,10 @@ checkPat' p@(TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") - ctEq (Scalar (tupleRecord ps_t)) t - checkPat' p $ Ascribed $ toParam Observe $ Scalar $ tupleRecord ps_t + ps_t :: [Type] <- replicateM (length ps) (newType loc "t") + ctEq (Scalar (tupleRecord ps_t)) (toType t) + st <- asStructType loc $ Scalar $ tupleRecord ps_t + checkPat' p $ Ascribed $ toParam Observe st checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) (Ascribed t) @@ -462,9 +463,10 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) L.sort (map fst p_fs) == L.sort (M.keys t_fs) = RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do - p_fs' <- traverse (const $ newType loc "t") $ M.fromList p_fs - ctEq (Scalar (Record p_fs') :: ParamType) t - checkPat' p $ Ascribed $ toParam Observe $ Scalar (Record p_fs') + p_fs' :: M.Map Name Type <- traverse (const $ newType loc "t") $ M.fromList p_fs + ctEq (Scalar (Record p_fs')) $ toType t + st <- asStructType loc $ Scalar (Record p_fs') + checkPat' p $ Ascribed $ toParam Observe st where check t_fs = traverse (uncurry checkPat') $ @@ -478,7 +480,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do case maybe_outer_t of Ascribed outer_t -> do - ctEq st outer_t + ctEq (toType st) (toType outer_t) PatAscription <$> checkPat' p (Ascribed (resToParam st)) <*> pure t' @@ -490,7 +492,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do t' <- patLitMkType l loc - addCt $ CtEq (toType t') (toType t) + ctEq (toType t') (toType t) pure $ PatLit l (Info t') loc checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc @@ -513,12 +515,14 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do p_t <- newType (srclocOf p) "t" checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' - ctEq t' t - pure $ PatConstr n (Info t') ps' loc + ctEq t' (toType t) + t'' <- asStructType loc t' + pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps t <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' - pure $ PatConstr n (Info t) ps' loc + t' <- asStructType loc t + pure $ PatConstr n (Info $ toParam Observe t') ps' loc checkPat :: PatBase NoInfo VName (TypeBase Size u) -> @@ -603,11 +607,11 @@ checkApply loc _ ftype arg = do where toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = - pure (a, toType b) + pure (a, b `setUniqueness` NoUniqueness) split ftype' = do a <- newType loc "arg" b <- newTyVar loc "res" - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] (tyVarType b) + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b pure (a, tyVarType b) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] @@ -635,7 +639,7 @@ mustHaveFields loc t [f] ve_t = do ctEq t rt mustHaveFields loc t (f : fs) ve_t = do ft :: Type <- newType loc "ft" - rt <- newTypeWithField loc "rt" f ft + rt :: Type <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t ctEq t rt @@ -660,7 +664,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq c_t cs_t + ctEq (toType c_t) (toType cs_t) pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -698,7 +702,7 @@ checkRetDecl :: checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te - ctEq (typeOf body) st + ctEq (expType body) (toType st) pure $ Just te' checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) @@ -749,7 +753,7 @@ checkExp (ArrayLit es _ loc) = do et <- newType loc "et" es' <- forM es $ \e -> do e' <- checkExp e - ctEq (typeOf e') et + ctEq (expType e') (toType et) pure e' let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et pure $ ArrayLit es' (Info arr_t) loc @@ -804,7 +808,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do onArg (i, f_t) (_, arg) = do arg' <- checkExp arg - (rt, am) <- checkApply loc (fname, i) (toType f_t) arg' + (rt, am) <- checkApply loc (fname, i) f_t arg' pure ( (i + 1, rt), (Info (Nothing, am), arg') @@ -831,7 +835,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do let t1 = typeOf e' t2 <- newType loc "t2" rt <- newType loc "rt" - ctEq optype $ foldFunType [toParam Observe t1, t2] $ RetType [] rt + ctEq (toType optype) $ toType $ foldFunType [toParam Observe t1, t2] $ RetType [] $ rt `setUniqueness` Nonunique pure $ OpSectionLeft op @@ -849,7 +853,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do t1 <- newType loc "t" let t2 = typeOf e' rt <- newType loc "rt" - ctEq optype $ foldFunType [t1, toParam Observe t2] $ RetType [] rt + ctEq (toType optype) $ toType $ foldFunType [t1, toParam Observe t2] $ RetType [] $ rt `setUniqueness` Nonunique pure $ OpSectionRight op @@ -866,7 +870,7 @@ checkExp (ProjectSection fields NoInfo loc) = do a <- newType loc "a" b <- newType loc "b" mustHaveFields loc (toType a) fields (toType b) - let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] b + let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc -- checkExp (Lambda params body retdecl NoInfo loc) = do @@ -916,7 +920,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do maybe_step' <- traverse checkExp' maybe_step end' <- traverse checkExp' end range_t <- newType loc "range" - ctEq range_t $ arrayOfRank 1 (toType (typeOf start')) + ctEq (toType range_t) (arrayOfRank 1 (expType start')) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] where checkExp' = require "use in range expression" anyIntType <=< checkExp @@ -924,9 +928,10 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do checkExp (Project k e NoInfo loc) = do e' <- checkExp e kt <- newType loc "kt" - t <- newTypeWithField loc "t" k kt - ctEq (typeOf e') t - pure $ Project k e' (Info kt) loc + t :: Type <- newTypeWithField loc "t" k kt + ctEq (expType e') t + kt' <- asStructType loc kt + pure $ Project k e' (Info kt') loc -- checkExp (RecordUpdate src fields ve NoInfo loc) = do src' <- checkExp src @@ -938,11 +943,12 @@ checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice index_arg_t <- newType loc "index" index_elem_t <- newType loc "index_elem" - index_res_t <- newType loc "index_res" + index_res_t :: Type <- newType loc "index_res" let num_slices = length $ filter isSlice slice - ctEq index_arg_t $ arrayOfRank num_slices index_elem_t + ctEq (toType index_arg_t) $ arrayOfRank num_slices index_elem_t ctEq index_res_t $ arrayOfRank (length slice) index_elem_t - let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ RetType [] index_res_t + index_res_t' <- asStructType loc index_res_t + let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' pure $ IndexSection slice' (Info ft) loc -- checkExp (AppExp (Index e slice loc) _) = do @@ -951,8 +957,8 @@ checkExp (AppExp (Index e slice loc) _) = do index_t <- newType loc "index" index_elem_t <- newType loc "index_elem" let num_slices = length $ filter isSlice slice - ctEq index_t $ arrayOfRank num_slices index_elem_t - ctEq (typeOf e') $ arrayOfRank (length slice) index_elem_t + ctEq (toType index_t) $ arrayOfRank num_slices index_elem_t + ctEq (expType e') $ arrayOfRank (length slice) index_elem_t pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) -- checkExp (Update src slice ve loc) = do @@ -961,8 +967,8 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq (typeOf src') $ arrayOfRank (length slice) update_elem_t - ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t + ctEq (expType src') $ arrayOfRank (length slice) update_elem_t + ctEq (expType ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do @@ -973,8 +979,8 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" - ctEq src_t $ arrayOfRank (length slice) update_elem_t - ctEq (typeOf ve') $ arrayOfRank num_slices update_elem_t + ctEq (toType src_t) $ arrayOfRank (length slice) update_elem_t + ctEq (expType ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) @@ -984,8 +990,8 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do e2' <- checkExp e2 e3' <- checkExp e3 - ctEq (typeOf e1') (Scalar (Prim Bool) :: Type) - ctEq (typeOf e2') (typeOf e3') + ctEq (expType e1') (Scalar (Prim Bool)) + ctEq (expType e2') (expType e3') pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) -- @@ -1004,17 +1010,17 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do let i' = Ident i (Info (typeOf bound')) iloc bind [i'] $ do body' <- checkExp body - ctEq (typeOf arg') (typeOf body') + ctEq (expType arg') (expType body') pure (For i' bound', body') While cond -> do cond' <- checkExp cond body' <- checkExp body - ctEq (typeOf arg') (typeOf body') + ctEq (expType arg') (expType body') pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" - ctEq (typeOf arr') $ arrayOfRank 1 (toType elem_t) + ctEq (expType arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') @@ -1026,12 +1032,12 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do checkExp (Ascript e te loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te - ctEq (typeOf e') st + ctEq (expType e') (toType st) pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te - ctEq (typeOf e') st + ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc checkValDef :: From d2e28288c2f23df68cf515125c3983b68cbe3794 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 19:35:25 -0800 Subject: [PATCH 027/296] Fall back to non-integral row echelon transformation for now. --- src/Futhark/Solve/Matrix.hs | 57 +++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs index 90e1a3e126..ae3bdf6b7c 100644 --- a/src/Futhark/Solve/Matrix.hs +++ b/src/Futhark/Solve/Matrix.hs @@ -281,33 +281,8 @@ update_ m upds = (nrows m) (ncols m) --- TODO: maintain integrality of entries in the matrix --- rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a --- rowEchelon = rowEchelon' 0 0 --- where --- rowEchelon' h k m@(Matrix _ nr nc) --- | h < nr && k < nc = --- if m ! (pivot_row, k) == 0 --- then rowEchelon' h (k + 1) m --- else rowEchelon' (h + 1) (k + 1) clear_rows_below --- | otherwise = m --- where --- pivot_row = --- fst $ --- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ --- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] --- m' = swapRows h pivot_row m --- clear_rows_below = --- update m' $ --- V.fromList $ --- [((i, k), 0) | i <- [h + 1 .. nr - 1]] --- ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) --- | i <- [h + 1 .. nr - 1], --- let f = m' ! (i, k) / m' ! (h, k), --- j <- [k + 1 .. nc - 1] --- ] - -rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +-- This version doesn't maintain integrality of the entries. +rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a rowEchelon = rowEchelon' 0 0 where rowEchelon' h k m@(Matrix _ nr nc) @@ -326,11 +301,37 @@ rowEchelon = rowEchelon' 0 0 update m' $ V.fromList $ [((i, k), 0) | i <- [h + 1 .. nr - 1]] - ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) + ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) | i <- [h + 1 .. nr - 1], + let f = m' ! (i, k) / m' ! (h, k), j <- [k + 1 .. nc - 1] ] +-- TODO: fix. Something's wrong here, causes huge blow-up. +-- rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +-- rowEchelon = rowEchelon' 0 0 +-- where +-- rowEchelon' h k m@(Matrix _ nr nc) +-- | h < nr && k < nc = +-- if m ! (pivot_row, k) == 0 +-- then rowEchelon' h (k + 1) m +-- else rowEchelon' (h + 1) (k + 1) clear_rows_below +-- | otherwise = m +-- where +-- pivot_row = +-- fst $ +-- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ +-- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] +-- m' = swapRows h pivot_row m +-- clear_rows_below = +-- update m' $ +-- V.fromList $ +-- [((i, k), 0) | i <- [h + 1 .. nr - 1]] +-- ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) +-- | i <- [h + 1 .. nr - 1], +-- j <- [k + 1 .. nc - 1] +-- ] + filterRows :: (Unbox a) => (Vector a -> Bool) -> Matrix a -> Matrix a filterRows p = fromVectors . filter p . toList From 18665d86a68e28e8b8d24cdd1bf3286dbefcd1af Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 19:42:50 -0800 Subject: [PATCH 028/296] Some new tests. --- .../Futhark/Solve/BranchAndBoundTests.hs | 68 +++++++--- unittests/Futhark/Solve/SimplexTests.hs | 124 ++++++++++++------ 2 files changed, 131 insertions(+), 61 deletions(-) diff --git a/unittests/Futhark/Solve/BranchAndBoundTests.hs b/unittests/Futhark/Solve/BranchAndBoundTests.hs index 10867a1bee..ed7e04c715 100644 --- a/unittests/Futhark/Solve/BranchAndBoundTests.hs +++ b/unittests/Futhark/Solve/BranchAndBoundTests.hs @@ -92,28 +92,54 @@ tests = Just (z, sol) -> and [ z `approxEq` (10 :: Double) + ], + -- testCase "6" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ] + + testCase "10" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "R2" ~+~ var "M3", + constraints = + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 + ] + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in assertBool + (unlines [show $ branchAndBound lp]) + $ case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double) ] - -- testCase "6" $ - -- let prog = - -- LinearProg - -- { optType = Maximize, - -- objective = var "x1" ~+~ var "x2", - -- constraints = - -- [ var "x1" ~<=~ constant 10, - -- var "x2" ~<=~ constant 5 - -- ] - -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) - -- } - -- (lp, idxmap) = linearProgToLP prog - -- lpe = convert lp - -- in assertBool - -- (unlines [show $ branchAndBound lp]) - -- $ case branchAndBound lp of - -- Nothing -> False - -- Just (z, sol) -> - -- and - -- [ z `approxEq` (10 :: Double) - -- ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool diff --git a/unittests/Futhark/Solve/SimplexTests.hs b/unittests/Futhark/Solve/SimplexTests.hs index 80eee3237e..1a52203d12 100644 --- a/unittests/Futhark/Solve/SimplexTests.hs +++ b/unittests/Futhark/Solve/SimplexTests.hs @@ -123,18 +123,15 @@ tests = } (lp, idxmap) = linearProgToLP prog lpe = convert lp - in trace - (unlines [show prog, show lp, show idxmap, show lpe]) - ( assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (10 :: Double), - and $ zipWith (==) (V.toList sol) [1, 0, 10] - ] - ), + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (10 :: Double), + and $ zipWith (==) (V.toList sol) [1, 0, 10] + ], testCase "8" $ let prog = LinearProg @@ -148,41 +145,88 @@ tests = } (lp, idxmap) = linearProgToLP prog lpe = convert lp - in trace - (unlines [show prog, show lp, show idxmap, show lpe]) - ( assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (15 :: Double) - ] - ), - testCase "9" $ + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (15 :: Double) + ], + -- testCase "9" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in trace + -- (unlines [show prog, show lp, show idxmap, show lpe]) + -- ( assertBool + -- (unlines [show $ simplexLP lp]) + -- $ case simplexLP lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (15 :: Double) + -- ] + -- ), + testCase "10" $ let prog = LinearProg - { optType = Maximize, - objective = var "x1" ~+~ var "x2", + { optType = Minimize, + objective = var "R2" ~+~ var "M3", constraints = - [ var "x1" ~<=~ constant 10, - var "x2" ~<=~ constant 5 + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 ] - <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) } (lp, idxmap) = linearProgToLP prog lpe = convert lp - in trace - (unlines [show prog, show lp, show idxmap, show lpe]) - ( assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (15 :: Double) - ] - ) + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double) + ], + testCase "11" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "4R" ~+~ var "5M", + constraints = + [ var "6artifical" ~==~ constant 1 ~+~ var "2t", + constant 1 ~+~ var "3num" ~==~ constant 1 ~+~ var "2t", + var "0b_R" ~<=~ constant 1, + var "1b_M" ~<=~ constant 1, + var "4R" ~<=~ 1000 ~*~ var "0b_R", + var "5M" ~<=~ 1000 ~*~ var "1b_M", + var "0b_R" ~+~ var "1b_M" ~<=~ constant 1 + ] + } + (lp, idxmap) = linearProgToLP prog + lpe = convert lp + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + and + [ z `approxEq` (0 :: Double) + ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool From 6380dd1c24db872b2fb50c06274446d01cfb0c7d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 20:14:05 -0800 Subject: [PATCH 029/296] Use frame-based AUTOMAP; removes need to distribute over function tys. --- src/Language/Futhark/Prop.hs | 6 +++++ src/Language/Futhark/TypeChecker/Rank.hs | 5 ++-- src/Language/Futhark/TypeChecker/Terms2.hs | 31 ++++++++++++---------- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 181f5d4135..65602c4689 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -32,6 +32,7 @@ module Language.Futhark.Prop funType, stripExp, similarExps, + frameOf, -- * Queries on patterns and params patIdents, @@ -1435,6 +1436,11 @@ similarExps (IndexSection slice1 _ _) (IndexSection slice2 _ _) = similarSlices slice1 slice2 similarExps _ _ = Nothing +frameOf :: Exp -> Shape Size +frameOf (AppExp (Apply _ args _) _) = + ((\(_, am) -> autoFrame am) . unInfo . fst) $ NE.last args +frameOf _ = mempty + -- | An identifier with type- and aliasing information. type Ident = IdentBase Info VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index ee9eaafdcc..8031982f21 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -118,8 +118,7 @@ mkLinearProg counter cs = rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) rankAnalysis counter cs = do - traceM $ unlines $ concat $ map (\c -> [prettyString c, show c]) cs' - traceM $ prettyString prog + traceM $ unlines ["rankAnalysis prog:", prettyString prog] (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map where @@ -133,7 +132,7 @@ rankAnalysis counter cs = do t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] - cs' = foldMap (splitFuncs . distribute) cs + cs' = foldMap splitFuncs cs prog = mkLinearProg counter cs' (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index c79c67f5e2..fda11bd604 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -590,21 +590,25 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Exp -> TermM (Type, AutoMap) -checkApply loc _ ftype arg = do +checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) +checkApply loc _ ftype fframe arg = do (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] r) unit_info mempty + lhs = arrayOf (toShape (SVar r) <> (toSComp <$> frameOf arg)) $ toType $ typeOf arg + rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a ctAM r m - ctEq (arrayOf (toShape $ SVar r) $ toType $ typeOf arg) (arrayOf (toShape $ SVar m) a) + ctEq lhs rhs pure - ( arrayOf (toShape $ SVar m) b, - AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = mempty} + ( b, + AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where + toSComp (Var (QualName [] x) _ _) = SVar x + toSComp _ = error "" toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) @@ -797,7 +801,7 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - ((_, rt), args') <- mapAccumLM onArg (0, expType fe') args + ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args rt' <- asStructType loc rt pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] where @@ -806,11 +810,11 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do Var v _ _ -> Just v _ -> Nothing - onArg (i, f_t) (_, arg) = do + onArg (i, f_t, f_f) (_, arg) = do arg' <- checkExp arg - (rt, am) <- checkApply loc (fname, i) f_t arg' + (rt, am) <- checkApply loc (fname, i) f_t f_f arg' pure - ( (i + 1, rt), + ( (i + 1, rt, autoFrame am), (Info (Nothing, am), arg') ) -- @@ -819,8 +823,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e1' <- checkExp e1 e2' <- checkExp e2 - (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) e1' - (rt2, am2) <- checkApply loc (Just op, 1) rt1 e2' + (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) mempty e1' + (rt2, am2) <- checkApply loc (Just op, 1) rt1 mempty e2' rt2' <- asStructType loc rt2 pure $ @@ -831,7 +835,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - void $ checkApply loc (Just op, 0) (toType optype) e' + void $ checkApply loc (Just op, 0) (toType optype) mempty e' let t1 = typeOf e' t2 <- newType loc "t2" rt <- newType loc "rt" @@ -1088,11 +1092,10 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do -- -- counter <- gets termCounter -- --- traceM $ unlines $ map prettyString cts --- -- case rankAnalysis counter cts of -- Nothing -> error "" -- Just rank_map -> do +-- traceM $ prettyString $ M.toList rank_map -- tyvars <- gets termTyVars -- -- let solution = solve cts tyvars From 37236d11dfcf624f9af0adcb4865de36e80cf5b1 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 20:27:26 -0800 Subject: [PATCH 030/296] Support `TyVarInfo` info in rank analysis && use rank analysis in the checker. --- src/Language/Futhark/TypeChecker/Rank.hs | 24 ++++++-- src/Language/Futhark/TypeChecker/Terms2.hs | 67 +++++++--------------- 2 files changed, 40 insertions(+), 51 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 8031982f21..adc0d32f72 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -24,6 +24,9 @@ type ScalarType = ScalarTypeBase SComp NoUniqueness class Rank a where rank :: a -> LSum +instance Rank VName where + rank = var + instance Rank SComp where rank SDim = constant 1 rank (SVar v) = var v @@ -98,8 +101,14 @@ addCt (CtAM r m) = do b_m <- binVar m addConstraints $ oneIsZero (b_r, r) (b_m, m) -mkLinearProg :: Int -> [Ct] -> LinearProg -mkLinearProg counter cs = +addTyVarInfo :: TyVar -> TyVarInfo -> RankM () +addTyVarInfo tv (TyVarFree) = pure () +addTyVarInfo tv (TyVarPrim _) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo _ _ = error "Unhandled" + +mkLinearProg :: Int -> [Ct] -> TyVars -> LinearProg +mkLinearProg counter cs tyVars = LP.LinearProg { optType = Minimize, objective = @@ -114,10 +123,13 @@ mkLinearProg counter cs = rankCounter = counter, rankConstraints = mempty } - finalState = flip execState initState $ runRankM $ mapM_ addCt cs + buildLP = do + mapM_ addCt cs + mapM_ (uncurry addTyVarInfo) $ M.toList tyVars + finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: Int -> [Ct] -> Maybe (Map VName Int) -rankAnalysis counter cs = do +rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe (Map VName Int) +rankAnalysis counter cs tyVars = do traceM $ unlines ["rankAnalysis prog:", prettyString prog] (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map @@ -133,6 +145,6 @@ rankAnalysis counter cs = do t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] cs' = foldMap splitFuncs cs - prog = mkLinearProg counter cs' + prog = mkLinearProg counter cs' tyVars (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index fda11bd604..d7484d7407 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1065,50 +1065,27 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do retdecl' <- checkRetDecl body' retdecl cts <- gets termConstraints + + counter <- gets termCounter + tyvars <- gets termTyVars - let solution = solve cts tyvars - - traceM $ - unlines - [ "# function " <> prettyNameString fname, - "## constraints:", - unlines $ map prettyString cts, - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, - "## solution:", - let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList) solution - ] - pure (solution, params', retdecl', body') - --- checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do --- bindParams tparams params $ \params' -> do --- body' <- checkExp body --- --- retdecl' <- checkRetDecl body' retdecl --- --- cts <- gets termConstraints --- --- counter <- gets termCounter --- --- case rankAnalysis counter cts of --- Nothing -> error "" --- Just rank_map -> do --- traceM $ prettyString $ M.toList rank_map --- tyvars <- gets termTyVars --- --- let solution = solve cts tyvars --- --- traceM $ --- unlines --- [ "# function " <> prettyNameString fname, --- "## constraints:", --- unlines $ map prettyString cts, --- "## tyvars:", --- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, --- "## solution:", --- let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t --- in either T.unpack (unlines . map p . M.toList) solution --- ] --- pure (solution, params', retdecl', body') + case rankAnalysis counter cts tyvars of + Nothing -> error "" + Just rank_map -> do + traceM $ prettyString $ M.toList rank_map + + let solution = solve cts tyvars + + traceM $ + unlines + [ "# function " <> prettyNameString fname, + "## constraints:", + unlines $ map prettyString cts, + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + "## solution:", + let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList) solution + ] + pure (solution, params', retdecl', body') From 8e8db497ddf13548d6f5ed6112564a6abb1506a4 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 11 Feb 2024 22:43:45 -0800 Subject: [PATCH 031/296] Basic support for substituting in rank info. --- .../Futhark/TypeChecker/Constraints.hs | 4 ++ src/Language/Futhark/TypeChecker/Rank.hs | 64 ++++++++++++++++++- src/Language/Futhark/TypeChecker/Terms2.hs | 10 ++- 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index a81812def5..d5b13aa049 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -8,6 +8,7 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, + TyVarSol (..), Solution, solve, ) @@ -70,6 +71,8 @@ data TyVarInfo TyVarRecord (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum (M.Map Name [Type]) + | -- | Must have at least this rank. + TyVarRank Int deriving (Show) instance Pretty TyVarInfo where @@ -77,6 +80,7 @@ instance Pretty TyVarInfo where pretty (TyVarPrim pts) = "∈" <+> pretty pts pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs + pretty (TyVarRank x) = "rank ≥" <+> pretty x type TyVar = VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index adc0d32f72..9a4eca4010 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -6,6 +6,8 @@ import Data.Map qualified as M import Data.Maybe import Data.Vector.Unboxed qualified as V import Debug.Trace +-- import Futhark.FreshNames qualified as FreshNames +-- import Futhark.MonadFreshNames hiding (newName) import Futhark.Solve.BranchAndBound import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP @@ -13,6 +15,8 @@ import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints +-- import Language.Futhark.TypeChecker.Monad (mkTypeVarName) + type LSum = LP.LSum VName Double type Constraint = LP.Constraint VName Double @@ -128,12 +132,20 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe (Map VName Int) +rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars) rankAnalysis counter cs tyVars = do traceM $ unlines ["rankAnalysis prog:", prettyString prog] (_size, ranks) <- branchAndBound lp - pure $ (fromJust . (ranks V.!?)) <$> inv_var_map + let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map + let (cs', SubstState tyVars') = + flip runState (SubstState mempty) $ + runSubstM $ + substRanks rank_map $ + filter (not . isCtAM) cs + pure (cs', tyVars <> tyVars') where + isCtAM (CtAM {}) = True + isCtAM _ = False splitFuncs ( CtEq (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) @@ -148,3 +160,51 @@ rankAnalysis counter cs tyVars = do prog = mkLinearProg counter cs' tyVars (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] + +newtype SubstM a = SubstM {runSubstM :: State SubstState a} + deriving (Functor, Applicative, Monad, MonadState SubstState) + +data SubstState = SubstState + { substTyVars :: TyVars + } + +rankToShape :: Map VName Int -> VName -> Shape SComp +rankToShape rs x = Shape $ replicate (rs M.! x) SDim + +addRankInfo :: Map VName Int -> TyVar -> SubstM () +addRankInfo rs t = + modify $ \s -> s {substTyVars = M.insert t (TyVarRank $ rs M.! t) $ substTyVars s} + +class SubstRanks a where + substRanks :: Map VName Int -> a -> SubstM a + +instance (SubstRanks a) => SubstRanks [a] where + substRanks rs = mapM (substRanks rs) + +instance SubstRanks (Shape SComp) where + substRanks rs = pure . foldMap instDim + where + instDim (SDim) = Shape $ pure SDim + instDim (SVar x) = rankToShape rs x + +instance SubstRanks (TypeBase SComp u) where + substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) + | rs M.! x > 0 = do + addRankInfo rs x + pure t + substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = do + ta' <- substRanks rs ta + tr' <- substRanks rs tr + pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) + substRanks rs (Array u shape t) = do + shape' <- substRanks rs shape + t' <- substRanks rs (Scalar t) + pure $ Array u (shape' <> arrayShape t') (scalarType t') + where + scalarType (Array _ _ t) = t + scalarType (Scalar t) = t + substRanks _ t = pure t + +instance SubstRanks Ct where + substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 + substRanks _ _ = error "" diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d7484d7407..019e0b0d8b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1072,18 +1072,16 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do case rankAnalysis counter cts tyvars of Nothing -> error "" - Just rank_map -> do - traceM $ prettyString $ M.toList rank_map - - let solution = solve cts tyvars + Just (cts', tyvars') -> do + let solution = solve cts' tyvars' traceM $ unlines [ "# function " <> prettyNameString fname, "## constraints:", - unlines $ map prettyString cts, + unlines $ map prettyString cts', "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution From d4d19138c3db9cc06b9560018aa1425f5f9f42b7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 11:54:38 +0100 Subject: [PATCH 032/296] Correct construction of array types. --- src/Language/Futhark/Prop.hs | 4 +++- src/Language/Futhark/TypeChecker/Rank.hs | 5 +---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 65602c4689..cc5c40268a 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -316,7 +316,9 @@ arrayOfWithAliases :: arrayOfWithAliases u shape2 (Array _ shape1 et) = Array u (shape2 <> shape1) et arrayOfWithAliases u shape (Scalar t) = - Array u shape (second (const mempty) t) + if shapeRank shape == 0 + then Scalar t `setUniqueness` u + else Array u shape (second (const mempty) t) -- | @stripArray n t@ removes the @n@ outermost layers of the array. -- Essentially, it is the type of indexing an array of type @t@ with diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9a4eca4010..b282d704c3 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -199,10 +199,7 @@ instance SubstRanks (TypeBase SComp u) where substRanks rs (Array u shape t) = do shape' <- substRanks rs shape t' <- substRanks rs (Scalar t) - pure $ Array u (shape' <> arrayShape t') (scalarType t') - where - scalarType (Array _ _ t) = t - scalarType (Scalar t) = t + pure $ arrayOfWithAliases u shape' t' substRanks _ t = pure t instance SubstRanks Ct where From 320d95444677ff4d6e4b0c5aa06123ee672338b3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:09:36 +0100 Subject: [PATCH 033/296] Forget size variables here. --- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 019e0b0d8b..bcb97efcbf 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -387,7 +387,7 @@ lookupVar loc qn@(QualName qs name) = do else do (tnames, t') <- instTypeScheme qn loc tparams t outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' + asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do argtype <- newType loc "t" pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool From 46ba3f99e7216f9f7c07d5e5e270f46d88032a35 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:46:27 +0100 Subject: [PATCH 034/296] Basic tracking of levels. --- .../Futhark/TypeChecker/Constraints.hs | 69 +++++++++++-------- src/Language/Futhark/TypeChecker/Rank.hs | 10 +-- src/Language/Futhark/TypeChecker/Terms.hs | 8 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 8 ++- 4 files changed, 56 insertions(+), 39 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index d5b13aa049..d3d145c57e 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -18,6 +18,7 @@ import Control.Monad.Except import Control.Monad.State import Data.Bifunctor import Data.Map qualified as M +import Data.Maybe import Data.Text qualified as T import Futhark.IR.Pretty import Futhark.Util.Pretty @@ -84,30 +85,31 @@ instance Pretty TyVarInfo where type TyVar = VName --- | If a VName is not in this map, it is assumed to be rigid. -type TyVars = M.Map TyVar TyVarInfo +-- | If a VName is not in this map, it is assumed to be rigid. The +-- integer is the level. +type TyVars = M.Map TyVar (Int, TyVarInfo) data TyVarSol = -- | Has been substituted with this. - TyVarSol Type + TyVarSol Int Type | -- | Replaced by this other type variable. TyVarLink VName | -- | Not substituted yet; has this constraint. - TyVarUnsol TyVarInfo + TyVarUnsol Int TyVarInfo deriving (Show) newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} initialState :: TyVars -> SolverState -initialState tyvars = SolverState $ M.map TyVarUnsol tyvars +initialState tyvars = SolverState $ M.map (uncurry TyVarUnsol) tyvars substTyVars :: (Monoid u) => M.Map TyVar TyVarSol -> TypeBase SComp u -> TypeBase SComp u substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = case M.lookup v m of Just (TyVarLink v') -> substTyVars m $ Scalar $ TypeVar u (QualName qs v') args - Just (TyVarSol t') -> second (const mempty) t' - Just (TyVarUnsol _) -> t + Just (TyVarSol _ t') -> second (const mempty) t' + Just (TyVarUnsol {}) -> t Nothing -> t substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs @@ -118,10 +120,11 @@ substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt -- | A solution maps types to the set of type variables that must be --- substituted with this type. This slightly odd representation is --- needed to encode when two type variables are actually the same --- type. This matters when we start instanting the sizes of the type. -type Solution = M.Map Type [TyVar] +-- substituted with this type, as well as its binding level. This +-- slightly odd representation is needed to encode when two type +-- variables are actually the same type. This matters when we start +-- instanting the sizes of the type. +type Solution = M.Map Type (Int, [TyVar]) solution :: SolverState -> Solution solution s = @@ -132,23 +135,23 @@ solution s = M.toList $ solverTyVars s where - mkSubst (TyVarSol t) = Just (t, []) + mkSubst (TyVarSol lvl t) = Just (lvl, (t, [])) mkSubst _ = Nothing addLinks m (v1, TyVarLink v2) = case M.lookup v2 $ solverTyVars s of Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) _ -> case M.lookup v2 m of Nothing -> m - Just (t, vs) -> M.insert v2 (t, v1 : vs) m + Just (t, (lvl, vs)) -> M.insert v2 (t, (lvl, v1 : vs)) m addLinks m _ = m - adjust (v, (t, vs)) = (t, v : vs) + adjust (v, (lvl, (t, vs))) = (t, (lvl, v : vs)) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) -subTyVar :: VName -> Type -> SolveM () -subTyVar v t = - modify $ \s -> s {solverTyVars = M.insert v (TyVarSol t) $ solverTyVars s} +subTyVar :: VName -> Int -> Type -> SolveM () +subTyVar v lvl t = + modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} linkTyVar :: VName -> VName -> SolveM () linkTyVar v t = @@ -190,13 +193,13 @@ solveCt ct = tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of Just (TyVarLink v') -> flexible v' - Just (TyVarUnsol _) -> True - Just (TyVarSol _) -> False - Nothing -> False + Just (TyVarUnsol lvl _) -> Just lvl + Just (TyVarSol _ _) -> Nothing + Nothing -> Nothing sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of Just (TyVarLink v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (TyVarSol t') -> sub t' + Just (TyVarSol _ t') -> sub t' _ -> t sub t = t case (sub t1, sub t2) of @@ -206,14 +209,22 @@ solveCt ct = | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> bad - (True, False) -> subTyVar v1 t2' - (False, True) -> subTyVar v2 t1' - (True, True) -> linkTyVar v1 v2 - (Scalar (TypeVar _ (QualName [] v1) []), t2') -> - if flexible v1 then subTyVar v1 t2' else bad - (t1', Scalar (TypeVar _ (QualName [] v2) [])) -> - if flexible v2 then subTyVar v2 t1' else bad + (Nothing, Nothing) -> bad + (Just lvl, Nothing) -> subTyVar v1 lvl t2' + (Nothing, Just lvl) -> subTyVar v2 lvl t1' + (Just lvl1, Just lvl2) + | lvl1 <= lvl2 -> linkTyVar v1 v2 + | otherwise -> linkTyVar v2 v1 + (Scalar (TypeVar _ (QualName [] v1) []), t2') + | Just lvl <- flexible v1 -> + subTyVar v1 lvl t2' + | otherwise -> + bad + (t1', Scalar (TypeVar _ (QualName [] v2) [])) + | Just lvl <- flexible v2 -> + subTyVar v2 lvl t1' + | otherwise -> + bad (t1', t2') -> case unify t1' t2' of Nothing -> bad Just eqs -> mapM_ solveCt' eqs diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index b282d704c3..bcb67d2a13 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -105,9 +105,9 @@ addCt (CtAM r m) = do b_m <- binVar m addConstraints $ oneIsZero (b_r, r) (b_m, m) -addTyVarInfo :: TyVar -> TyVarInfo -> RankM () -addTyVarInfo tv (TyVarFree) = pure () -addTyVarInfo tv (TyVarPrim _) = +addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () +addTyVarInfo tv (_, TyVarFree) = pure () +addTyVarInfo tv (_, TyVarPrim _) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo _ _ = error "Unhandled" @@ -173,7 +173,9 @@ rankToShape rs x = Shape $ replicate (rs M.! x) SDim addRankInfo :: Map VName Int -> TyVar -> SubstM () addRankInfo rs t = - modify $ \s -> s {substTyVars = M.insert t (TyVarRank $ rs M.! t) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (lvl, TyVarRank $ rs M.! t) $ substTyVars s} + where + lvl = 0 -- FIXME class SubstRanks a where substRanks :: Map VName Int -> a -> SubstM a diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2a7bbbd6da..22331fe113 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1622,15 +1622,15 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t -addInitialConstraints :: M.Map (TypeBase () NoUniqueness) [VName] -> TermTypeM () +addInitialConstraints :: M.Map (TypeBase () NoUniqueness) (Int, [VName]) -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where - addConstraint v c = modifyConstraints $ M.insert v (0, c) + addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) usage = mkUsage (mempty :: Loc) - f (t, vs) = do + f (t, (lvl, vs)) = do (t', _) <- allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t forM_ vs $ \v -> - addConstraint v $ Constraint (RetType [] t') $ usage $ prettyNameText v + addConstraint v lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index bcb97efcbf..6c2194511a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -202,6 +202,9 @@ runTermM (TermM m) = do incLevel :: TermM a -> TermM a incLevel = local $ \env -> env {termLevel = termLevel env + 1} +curLevel :: TermM Int +curLevel = asks termLevel + incCounter :: TermM Int incCounter = do s <- get @@ -215,7 +218,8 @@ newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar newTyVarWith _loc desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i - modify $ \s -> s {termTyVars = M.insert v info $ termTyVars s} + lvl <- curLevel + modify $ \s -> s {termTyVars = M.insert v (lvl, info) $ termTyVars s} pure v newTyVar :: (Located loc) => loc -> Name -> TermM TyVar @@ -1083,7 +1087,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", - let p (t, vs) = unwords (map prettyNameString vs) <> " => " <> prettyString t + let p (t, (lvl, vs)) = unwords (show [lvl] : map prettyNameString vs) <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') From 957b46232b513f754078ae84fc46bc91e636303f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:51:22 +0100 Subject: [PATCH 035/296] Create sizes at right level. --- src/Language/Futhark/TypeChecker/Terms.hs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 22331fe113..96eb651bb3 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1622,13 +1622,18 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t +-- A hack we need to create size variables for types at the right +-- level. +atLevel :: Int -> TermTypeM a -> TermTypeM a +atLevel lvl = local $ \env -> env {termLevel = lvl} + addInitialConstraints :: M.Map (TypeBase () NoUniqueness) (Int, [VName]) -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) usage = mkUsage (mempty :: Loc) f (t, (lvl, vs)) = do - (t', _) <- allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t + (t', _) <- atLevel lvl $ allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t forM_ vs $ \v -> addConstraint v lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v From ceaa72fa3b570baa7d0dae6fff4e8b5578b57ab2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:55:46 +0100 Subject: [PATCH 036/296] Handle more cases. --- src/Language/Futhark/TypeChecker/Rank.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index bcb67d2a13..d107f6465b 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -40,9 +40,11 @@ instance Rank (Shape SComp) where instance Rank ScalarType where rank Prim {} = constant 0 - rank (TypeVar _ (QualName [] v) []) = var v + rank (TypeVar _ (QualName [] v) []) = var v -- FIXME - might not be a type variable. + rank (TypeVar {}) = constant 0 rank (Arrow {}) = constant 0 - rank t = error $ prettyString t + rank (Record {}) = constant 0 + rank (Sum {}) = constant 0 instance Rank Type where rank (Scalar t) = rank t From 3d8a1a14dc4348986e531d69e85dfb7ac22da80a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 12:57:34 +0100 Subject: [PATCH 037/296] Consistent printing. --- src/Language/Futhark/TypeChecker/Rank.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index d107f6465b..547c2f2cab 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -136,7 +136,7 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars) rankAnalysis counter cs tyVars = do - traceM $ unlines ["rankAnalysis prog:", prettyString prog] + traceM $ unlines ["## rankAnalysis prog", prettyString prog] (_size, ranks) <- branchAndBound lp let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map let (cs', SubstState tyVars') = diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6c2194511a..227642e1b5 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1074,6 +1074,8 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do tyvars <- gets termTyVars + traceM $ "# function " <> prettyNameString fname + case rankAnalysis counter cts tyvars of Nothing -> error "" Just (cts', tyvars') -> do @@ -1081,8 +1083,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do traceM $ unlines - [ "# function " <> prettyNameString fname, - "## constraints:", + [ "## constraints:", unlines $ map prettyString cts', "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', From 64ae342ea8946e19ff94d4720c42daff7407a35e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 13:19:03 +0100 Subject: [PATCH 038/296] Fix construction of solution. --- src/Language/Futhark/TypeChecker/Constraints.hs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index d3d145c57e..af55a6d243 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -17,6 +17,7 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor +import Data.List qualified as L import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T @@ -128,12 +129,11 @@ type Solution = M.Map Type (Int, [TyVar]) solution :: SolverState -> Solution solution s = - M.fromList $ - map adjust $ - M.toList $ - foldl addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ - M.toList $ - solverTyVars s + L.foldl' byType mempty $ + M.toList $ + L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ + M.toList $ + solverTyVars s where mkSubst (TyVarSol lvl t) = Just (lvl, (t, [])) mkSubst _ = Nothing @@ -144,7 +144,9 @@ solution s = Nothing -> m Just (t, (lvl, vs)) -> M.insert v2 (t, (lvl, v1 : vs)) m addLinks m _ = m - adjust (v, (lvl, (t, vs))) = (t, (lvl, v : vs)) + byType m (v, (lvl, (t, vs))) = M.insertWith comb t (lvl, v : vs) m + where + comb (lvl1, ts1) (lvl2, ts2) = (min lvl1 lvl2, ts1 <> ts2) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) From ca5c562b92a0162ec2111b17188f4344fbabb924 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 04:44:21 -0800 Subject: [PATCH 039/296] Spaghetti code to better add in rank info. --- .../Futhark/TypeChecker/Constraints.hs | 4 - src/Language/Futhark/TypeChecker/Rank.hs | 197 ++++++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 7 +- 3 files changed, 163 insertions(+), 45 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index af55a6d243..67131e67cc 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -8,7 +8,6 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, - TyVarSol (..), Solution, solve, ) @@ -73,8 +72,6 @@ data TyVarInfo TyVarRecord (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum (M.Map Name [Type]) - | -- | Must have at least this rank. - TyVarRank Int deriving (Show) instance Pretty TyVarInfo where @@ -82,7 +79,6 @@ instance Pretty TyVarInfo where pretty (TyVarPrim pts) = "∈" <+> pretty pts pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs - pretty (TyVarRank x) = "rank ≥" <+> pretty x type TyVar = VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 547c2f2cab..c87e10a402 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -1,21 +1,21 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where +import Control.Monad.Reader import Control.Monad.State import Data.Map (Map) import Data.Map qualified as M import Data.Maybe import Data.Vector.Unboxed qualified as V import Debug.Trace --- import Futhark.FreshNames qualified as FreshNames --- import Futhark.MonadFreshNames hiding (newName) +import Futhark.FreshNames qualified as FreshNames +import Futhark.MonadFreshNames hiding (newName) import Futhark.Solve.BranchAndBound import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints - --- import Language.Futhark.TypeChecker.Monad (mkTypeVarName) +import Language.Futhark.TypeChecker.Monad (mkTypeVarName) type LSum = LP.LSum VName Double @@ -134,17 +134,29 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars) -rankAnalysis counter cs tyVars = do +rankAnalysis :: VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) +rankAnalysis vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] (_size, ranks) <- branchAndBound lp let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map - let (cs', SubstState tyVars') = - flip runState (SubstState mempty) $ - runSubstM $ - substRanks rank_map $ - filter (not . isCtAM) cs - pure (cs', tyVars <> tyVars') + initEnv = + SubstEnv + { envTyVars = tyVars, + envRanks = rank_map + } + + initState = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNameSource = vns, + substCounter = counter + } + (cs', state') = + runSubstM initEnv initState $ + substRanks $ + filter (not . isCtAM) cs + pure (cs', substTyVars state' <> tyVars, substNameSource state', substCounter state') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -163,49 +175,156 @@ rankAnalysis counter cs tyVars = do (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] -newtype SubstM a = SubstM {runSubstM :: State SubstState a} - deriving (Functor, Applicative, Monad, MonadState SubstState) +newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) + deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) -data SubstState = SubstState - { substTyVars :: TyVars +runSubstM :: SubstEnv -> SubstState -> SubstM a -> (a, SubstState) +runSubstM initEnv initState (SubstM m) = + runReader (runStateT m initState) initEnv + +data SubstEnv = SubstEnv + { envTyVars :: TyVars, + envRanks :: Map VName Int } -rankToShape :: Map VName Int -> VName -> Shape SComp -rankToShape rs x = Shape $ replicate (rs M.! x) SDim +data SubstState = SubstState + { substTyVars :: TyVars, + substNewVars :: Map TyVar TyVar, + substNameSource :: VNameSource, + substCounter :: !Int + } -addRankInfo :: Map VName Int -> TyVar -> SubstM () -addRankInfo rs t = - modify $ \s -> s {substTyVars = M.insert t (lvl, TyVarRank $ rs M.! t) $ substTyVars s} +substIncCounter :: SubstM Int +substIncCounter = do + s <- get + put s {substCounter = substCounter s + 1} + pure $ substCounter s + +newTyVar :: TyVar -> SubstM TyVar +newTyVar t = do + i <- substIncCounter + t' <- newID $ mkTypeVarName (baseName t) i + modify $ \s -> s {substNewVars = M.insert t t' $ substNewVars s} + pure t' + where + newID x = do + s <- get + let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 + put $ s {substNameSource = src'} + pure v' + +rankToShape :: VName -> SubstM (Shape SComp) +rankToShape x = do + rs <- asks envRanks + pure $ Shape $ replicate (rs M.! x) SDim + +addRankInfo :: TyVar -> SubstM TyVar +addRankInfo t = do + rs <- asks envRanks + if rs M.! t == 0 + then pure t + else do + new_vars <- gets substNewVars + maybe new_var pure $ new_vars M.!? t where lvl = 0 -- FIXME + new_var = do + t' <- newTyVar t + old_tyvars <- asks envTyVars + case old_tyvars M.!? t of + Nothing -> pure t' + Just info -> do + modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} + pure t' class SubstRanks a where - substRanks :: Map VName Int -> a -> SubstM a + substRanks :: a -> SubstM a instance (SubstRanks a) => SubstRanks [a] where - substRanks rs = mapM (substRanks rs) + substRanks = mapM substRanks instance SubstRanks (Shape SComp) where - substRanks rs = pure . foldMap instDim + substRanks = foldM (\s d -> (s <>) <$> instDim d) mempty where - instDim (SDim) = Shape $ pure SDim - instDim (SVar x) = rankToShape rs x + instDim (SDim) = pure $ Shape $ pure SDim + instDim (SVar x) = rankToShape x instance SubstRanks (TypeBase SComp u) where - substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) - | rs M.! x > 0 = do - addRankInfo rs x - pure t - substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = do - ta' <- substRanks rs ta - tr' <- substRanks rs tr + substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = do + x' <- addRankInfo x + pure $ (Scalar (TypeVar u (QualName [] x') [])) + substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do + ta' <- substRanks ta + tr' <- substRanks tr pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) - substRanks rs (Array u shape t) = do - shape' <- substRanks rs shape - t' <- substRanks rs (Scalar t) + substRanks (Array u shape t) = do + shape' <- substRanks shape + t' <- substRanks $ Scalar t pure $ arrayOfWithAliases u shape' t' - substRanks _ t = pure t + substRanks t = pure t instance SubstRanks Ct where - substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 - substRanks _ _ = error "" + substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 + substRanks _ = error "" + +-- data SubstState = SubstState +-- { substTyVars :: Map TyVar TyVarSol, +-- substNameSource :: VNameSource, +-- substCounter :: !Int +-- } +-- +-- newtype SubstM a = SubstM {runSubstM :: State SubstState a} +-- deriving (Functor, Applicative, Monad, MonadState SubstState) +-- +-- substIncCounter :: SubstM Int +-- substIncCounter = do +-- s <- get +-- put s {substCounter = substCounter s + 1} +-- pure $ substCounter s +-- +-- newTyVar :: Name -> SubstM TyVar +-- newTyVar desc = do +-- i <- substIncCounter +-- newID $ mkTypeVarName desc i +-- where +-- newID x = do +-- s <- get +-- let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 +-- put $ s {substNameSource = src'} +-- pure v' +-- +-- addTyVarSol :: TyVar -> Shape SComp -> SubstM TyVar +-- addTyVarSol t shape = do +-- m <- subsTyVars gets +-- case m M.!? t of +-- Nothing -> do +-- t' <- newTyVar $ baseName t +-- modify $ \s -> s {substTyVars = M.insert t () $ substTyVars s} +-- Just t' -> pure t' +-- +-- rankToShape :: Map VName Int -> VName -> Shape SComp +-- rankToShape rs x = Shape $ replicate (rs M.! x) SDim +-- +-- class SubstRanks a where +-- substRanks :: Map VName Int -> a -> SubstM a +-- +-- instance SubstRanks (Shape SComp) where +-- substRanks rs = pure . foldMap instDim +-- where +-- instDim (SDim) = Shape $ pure SDim +-- instDim (SVar x) = rankToShape rs x +-- +-- instance SubstRanks (TypeBase SComp u) where +-- substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) +-- | rs M.! x > 0 = do +-- t' <- newTyVar $ baseName t +-- t' <- addTyVarSol +-- arrayOfWithAliases u (rankToShape rs x) t +-- substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = +-- Scalar (Arrow u p d (substRanks rs ta) (RetType retdims (substRanks rs tr))) +-- substRanks _ t = t +-- +-- instance SubstRanks Ct where +-- substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 +-- substRanks _ _ = error "" diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 227642e1b5..4fa481140c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1076,9 +1076,12 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do traceM $ "# function " <> prettyNameString fname - case rankAnalysis counter cts tyvars of + vns <- gets termNameSource + + case rankAnalysis vns counter cts tyvars of Nothing -> error "" - Just (cts', tyvars') -> do + Just (cts', tyvars', vns', counter') -> do + modify $ \s -> s {termCounter = counter', termNameSource = vns'} let solution = solve cts' tyvars' traceM $ From e18fe2e7678df35b821fe44f9b1eaab4a90c9232 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 04:50:08 -0800 Subject: [PATCH 040/296] Print out the rank map too. --- src/Language/Futhark/TypeChecker/Rank.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c87e10a402..4293097d96 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -139,7 +139,8 @@ rankAnalysis vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] (_size, ranks) <- branchAndBound lp let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map - initEnv = + traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + let initEnv = SubstEnv { envTyVars = tyVars, envRanks = rank_map From 344b747b9c8bad5acdbf89e1376559d611726c9c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 05:05:55 -0800 Subject: [PATCH 041/296] Forgot to actually generate the new type variable constraints. --- src/Language/Futhark/TypeChecker/Rank.hs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 4293097d96..9412553f82 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -151,13 +151,14 @@ rankAnalysis vns counter cs tyVars = do { substTyVars = mempty, substNewVars = mempty, substNameSource = vns, - substCounter = counter + substCounter = counter, + substNewCts = mempty } (cs', state') = runSubstM initEnv initState $ substRanks $ filter (not . isCtAM) cs - pure (cs', substTyVars state' <> tyVars, substNameSource state', substCounter state') + pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -192,7 +193,8 @@ data SubstState = SubstState { substTyVars :: TyVars, substNewVars :: Map TyVar TyVar, substNameSource :: VNameSource, - substCounter :: !Int + substCounter :: !Int, + substNewCts :: [Ct] } substIncCounter :: SubstM Int @@ -205,7 +207,17 @@ newTyVar :: TyVar -> SubstM TyVar newTyVar t = do i <- substIncCounter t' <- newID $ mkTypeVarName (baseName t) i - modify $ \s -> s {substNewVars = M.insert t t' $ substNewVars s} + shape <- rankToShape t + modify $ \s -> + s + { substNewVars = M.insert t t' $ substNewVars s, + substNewCts = + substNewCts s + ++ [ CtEq + (Scalar (TypeVar mempty (QualName [] t) [])) + (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) + ] + } pure t' where newID x = do From b2427fa241f608f29c8da5f52efe34a24ebd712f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 05:20:32 -0800 Subject: [PATCH 042/296] Bug fixes. --- src/Language/Futhark/TypeChecker/Rank.hs | 85 +++--------------------- 1 file changed, 9 insertions(+), 76 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9412553f82..cf774947ff 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -231,25 +231,20 @@ rankToShape x = do rs <- asks envRanks pure $ Shape $ replicate (rs M.! x) SDim -addRankInfo :: TyVar -> SubstM TyVar +addRankInfo :: TyVar -> SubstM () addRankInfo t = do rs <- asks envRanks - if rs M.! t == 0 - then pure t - else do - new_vars <- gets substNewVars - maybe new_var pure $ new_vars M.!? t + unless (rs M.! t == 0) $ do + new_vars <- gets substNewVars + maybe new_var (const $ pure ()) $ new_vars M.!? t where lvl = 0 -- FIXME new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars - case old_tyvars M.!? t of - Nothing -> pure t' - Just info -> do - modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} - pure t' + let info = old_tyvars M.! t + modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} class SubstRanks a where substRanks :: a -> SubstM a @@ -264,9 +259,8 @@ instance SubstRanks (Shape SComp) where instDim (SVar x) = rankToShape x instance SubstRanks (TypeBase SComp u) where - substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = do - x' <- addRankInfo x - pure $ (Scalar (TypeVar u (QualName [] x') [])) + substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = + addRankInfo x >> pure t substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do ta' <- substRanks ta tr' <- substRanks tr @@ -280,64 +274,3 @@ instance SubstRanks (TypeBase SComp u) where instance SubstRanks Ct where substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" - --- data SubstState = SubstState --- { substTyVars :: Map TyVar TyVarSol, --- substNameSource :: VNameSource, --- substCounter :: !Int --- } --- --- newtype SubstM a = SubstM {runSubstM :: State SubstState a} --- deriving (Functor, Applicative, Monad, MonadState SubstState) --- --- substIncCounter :: SubstM Int --- substIncCounter = do --- s <- get --- put s {substCounter = substCounter s + 1} --- pure $ substCounter s --- --- newTyVar :: Name -> SubstM TyVar --- newTyVar desc = do --- i <- substIncCounter --- newID $ mkTypeVarName desc i --- where --- newID x = do --- s <- get --- let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 --- put $ s {substNameSource = src'} --- pure v' --- --- addTyVarSol :: TyVar -> Shape SComp -> SubstM TyVar --- addTyVarSol t shape = do --- m <- subsTyVars gets --- case m M.!? t of --- Nothing -> do --- t' <- newTyVar $ baseName t --- modify $ \s -> s {substTyVars = M.insert t () $ substTyVars s} --- Just t' -> pure t' --- --- rankToShape :: Map VName Int -> VName -> Shape SComp --- rankToShape rs x = Shape $ replicate (rs M.! x) SDim --- --- class SubstRanks a where --- substRanks :: Map VName Int -> a -> SubstM a --- --- instance SubstRanks (Shape SComp) where --- substRanks rs = pure . foldMap instDim --- where --- instDim (SDim) = Shape $ pure SDim --- instDim (SVar x) = rankToShape rs x --- --- instance SubstRanks (TypeBase SComp u) where --- substRanks rs t@(Scalar (TypeVar u (QualName [] x) [])) --- | rs M.! x > 0 = do --- t' <- newTyVar $ baseName t --- t' <- addTyVarSol --- arrayOfWithAliases u (rankToShape rs x) t --- substRanks rs (Scalar (Arrow u p d ta (RetType retdims tr))) = --- Scalar (Arrow u p d (substRanks rs ta) (RetType retdims (substRanks rs tr))) --- substRanks _ t = t --- --- instance SubstRanks Ct where --- substRanks rs (CtEq t1 t2) = CtEq <$> substRanks rs t1 <*> substRanks rs t2 --- substRanks _ _ = error "" From b00301c840e2be18abe3995c2022650a5cf11dbb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 14:28:03 +0100 Subject: [PATCH 043/296] Unused now. --- src/Language/Futhark/TypeChecker/Rank.hs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index cf774947ff..7582b27157 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -238,7 +238,6 @@ addRankInfo t = do new_vars <- gets substNewVars maybe new_var (const $ pure ()) $ new_vars M.!? t where - lvl = 0 -- FIXME new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars From 3b125122a42b42446a6da976358aadd2cb30f029 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 14:52:56 +0100 Subject: [PATCH 044/296] Has to be this way around for dumb reasons. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 96eb651bb3..bdde3a119d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -525,7 +525,7 @@ checkExp (QualParens (modname, modnameloc) e loc) = do "Module" <+> pretty modname <+> " is a parametric module." checkExp (Var qn (Info t) loc) = do t' <- lookupVar loc qn - unify (mkUsage loc "inferred rank type") t t' + unify (mkUsage loc "inferred rank type") t' t pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg From 36ae5c9e9c0a13242fea27e2d673f84a7e229a3c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 14:55:45 +0100 Subject: [PATCH 045/296] Always connect rank type. --- src/Language/Futhark/TypeChecker/Terms.hs | 19 +++--- .../Futhark/TypeChecker/Terms/Monad.hs | 61 ++++++++++--------- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index bdde3a119d..816f334069 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -469,8 +469,8 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp (AppExp (BinOp (op, oploc) _ (e1, _) (e2, _) loc) _) = do - ftype <- lookupVar oploc op +checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do + ftype <- lookupVar oploc op op_t e1' <- checkExp e1 e2' <- checkExp e2 @@ -524,8 +524,7 @@ checkExp (QualParens (modname, modnameloc) e loc) = do typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." checkExp (Var qn (Info t) loc) = do - t' <- lookupVar loc qn - unify (mkUsage loc "inferred rank type") t' t + t' <- lookupVar loc qn t pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg @@ -715,11 +714,11 @@ checkExp (Lambda params body rettype_te _ loc) = do onDim _ = mempty pure $ RetType (S.toList $ foldMap onDim $ fvVars $ freeInType ret) ret -checkExp (OpSection op _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSection op (Info op_t) loc) = do + ftype <- lookupVar loc op op_t pure $ OpSection op (Info ftype) loc -checkExp (OpSectionLeft op _ e _ _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do + ftype <- lookupVar loc op op_t e' <- checkExp e (t1, rt, argext, retext) <- checkApply loc (Just op, 0) ftype e' case (ftype, rt) of @@ -735,8 +734,8 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (OpSectionRight op _ e _ _ loc) = do - ftype <- lookupVar loc op +checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do + ftype <- lookupVar loc op op_t e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index e8525e0dfd..f5ca46c85c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -448,38 +448,41 @@ instance MonadTypeChecker TermTypeM where Nothing -> throwError $ TypeError (locOf loc) notes s -lookupVar :: SrcLoc -> QualName VName -> TermTypeM StructType -lookupVar loc qn@(QualName qs name) = do +lookupVar :: SrcLoc -> QualName VName -> StructType -> TermTypeM StructType +lookupVar loc qn@(QualName qs name) t = do scope <- lookupQualNameEnv qn let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) - case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams t) -> do - when (null qs) . modify $ \s -> - s {stateUsed = S.insert name $ stateUsed s} - if null tparams && null qs - then pure t - else do - (tnames, t') <- instantiateTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newTypeVar loc "t" - equalityType usage argtype - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeVar loc "t" - mustBeOneOf ts usage argtype - let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + t' <- + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams bound_t) -> do + when (null qs) . modify $ \s -> + s {stateUsed = S.insert name $ stateUsed s} + if null tparams && null qs + then pure t + else do + (tnames, t') <- instantiateTypeScheme qn loc tparams bound_t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newTypeVar loc "t" + equalityType usage argtype + pure $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ + Scalar $ + Arrow mempty Unnamed Observe argtype $ + RetType [] $ + Scalar $ + Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeVar loc "t" + mustBeOneOf ts usage argtype + let (pts', rt') = instOverloaded argtype pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + unify (mkUsage loc "inferred rank type") t' t + pure t' where instOverloaded argtype pts rt = ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, From 7d4bad6a7de2f05d3c6843ac1e1d78e5981b4b8d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 16:16:33 +0100 Subject: [PATCH 046/296] Various fixes. --- .../Futhark/TypeChecker/Constraints.hs | 1 + src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 67131e67cc..1886be827c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -20,6 +20,7 @@ import Data.List qualified as L import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T +import Debug.Trace import Futhark.IR.Pretty import Futhark.Util.Pretty import Language.Futhark diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 4fa481140c..0d7ec4e01d 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -342,8 +342,7 @@ instance MonadTypeChecker TermM where --- All the general machinery goes above. arrayOfRank :: Int -> Type -> Type -arrayOfRank 0 t = t -arrayOfRank n t = arrayOf (Shape $ replicate n SDim) t +arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why pts e = do @@ -449,17 +448,16 @@ checkPat' (Wildcard _ loc) (Ascribed t) = checkPat' (Wildcard NoInfo loc) NoneInferred = do t <- newType loc "t" pure $ Wildcard (Info t) loc -checkPat' p@(TuplePat ps loc) (Ascribed t) +checkPat' (TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, length ts == length ps = TuplePat <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t :: [Type] <- replicateM (length ps) (newType loc "t") - ctEq (Scalar (tupleRecord ps_t)) (toType t) - st <- asStructType loc $ Scalar $ tupleRecord ps_t - checkPat' p $ Ascribed $ toParam Observe st + ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") + ctEq (toType (Scalar (tupleRecord ps_t))) (toType t) + TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) (Ascribed t) @@ -924,14 +922,16 @@ checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) (Info $ AppRes (typeOf body') []) -- checkExp (AppExp (Range start maybe_step end loc) _) = do - start' <- checkExp' start - maybe_step' <- traverse checkExp' maybe_step - end' <- traverse checkExp' end - range_t <- newType loc "range" - ctEq (toType range_t) (arrayOfRank 1 (expType start')) - pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] - where - checkExp' = require "use in range expression" anyIntType <=< checkExp + start' <- require "use in range expression" anyIntType =<< checkExp start + let check e = do + e' <- checkExp e + ctEq (expType start') (expType e') + pure e' + maybe_step' <- traverse check maybe_step + end' <- traverse check end + range_t <- newTyVar loc "range" + ctEq (tyVarType range_t :: Type) (arrayOfRank 1 (expType start')) + pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes (tyVarType range_t) [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e @@ -1074,7 +1074,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "# function " <> prettyNameString fname + traceM $ "\n# function " <> prettyNameString fname <> "\n" vns <- gets termNameSource From 6a958f0a558d49f4ce40ce0b0a84b98271b16c64 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 11:05:19 -0800 Subject: [PATCH 047/296] Add options to set all ranks to zero for debugging. --- src/Language/Futhark/TypeChecker/Rank.hs | 12 ++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 4 +++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 7582b27157..ed3c0e24d1 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -134,11 +134,15 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -rankAnalysis :: VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) -rankAnalysis vns counter cs tyVars = do +rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) +rankAnalysis debug_zero_ranks vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] - (_size, ranks) <- branchAndBound lp - let rank_map = (fromJust . (ranks V.!?)) <$> inv_var_map + rank_map <- + if debug_zero_ranks + then pure $ fmap (const 0) inv_var_map + else do + (_size, ranks) <- branchAndBound lp + pure $ (fromJust . (ranks V.!?)) <$> inv_var_map traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) let initEnv = SubstEnv diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 0d7ec4e01d..7021afd3d3 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1078,7 +1078,9 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do vns <- gets termNameSource - case rankAnalysis vns counter cts tyvars of + let debug_zero_ranks = True + + case rankAnalysis debug_zero_ranks vns counter cts tyvars of Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} From 989e1e30558f1bb038f3747638d3d71ae45c1b9a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 12:06:20 -0800 Subject: [PATCH 048/296] Use PuLP instead of setting ranks to zero. --- futhark.cabal | 2 + src/Futhark/Solve/LP.hs | 4 +- src/Language/Futhark/TypeChecker/Rank.hs | 44 +++++++++++++++++++--- src/Language/Futhark/TypeChecker/Terms2.hs | 10 ++++- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index ded1234442..c73b997f71 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -493,6 +493,8 @@ library , mwc-random , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 + -- remove me later + , process executable futhark import: common diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index af9265f458..a2b625a5e0 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -172,7 +172,9 @@ linearProgToPulp prog = show (prettyName v) <> " = " <> "LpVariable(" - <> show (show (prettyName v)) + <> "'" + <> show (prettyName v) + <> "_'" <> ", lowBound = 0, cat = 'Integer')" ) prog_vars diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index ed3c0e24d1..83d7d5dce8 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -2,6 +2,7 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where import Control.Monad.Reader import Control.Monad.State +import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M import Data.Maybe @@ -16,6 +17,8 @@ import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad (mkTypeVarName) +import System.IO.Unsafe +import System.Process type LSum = LP.LSum VName Double @@ -135,11 +138,16 @@ mkLinearProg counter cs tyVars = finalState = flip execState initState $ runRankM buildLP rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) -rankAnalysis debug_zero_ranks vns counter cs tyVars = do +rankAnalysis use_python vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- - if debug_zero_ranks - then pure $ fmap (const 0) inv_var_map + if use_python + then do + -- traceM $ linearProgToPulp prog + parseRes $ + unsafePerformIO $ + readProcess "python" [] $ + linearProgToPulp prog else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map @@ -181,6 +189,29 @@ rankAnalysis debug_zero_ranks vns counter cs tyVars = do (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] + rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" + vname_to_pulp_var = M.mapWithKey (\k _ -> map rm_subscript $ show $ prettyName k) inv_var_map + pulp_var_to_vname = + M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList vname_to_pulp_var] + + parseRes :: String -> Maybe (Map VName Int) + parseRes s = do + (status : vars) <- trimToStart $ lines s + if not (success status) + then Nothing + else do + pure $ M.fromList $ catMaybes $ map readVar vars + where + trimToStart [] = Nothing + trimToStart (l : ls) + | "status" `L.isPrefixOf` l = Just (l : ls) + | otherwise = trimToStart ls + success l = + (read $ drop (length ("status: " :: [Char])) l) == (1 :: Int) + readVar xs = + let (v, _ : value) = L.span (/= ':') xs + in Just (fromJust $ pulp_var_to_vname M.!? v, read value) + newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) @@ -233,19 +264,20 @@ newTyVar t = do rankToShape :: VName -> SubstM (Shape SComp) rankToShape x = do rs <- asks envRanks - pure $ Shape $ replicate (rs M.! x) SDim + pure $ Shape $ replicate (fromJust $ rs M.!? x) SDim addRankInfo :: TyVar -> SubstM () addRankInfo t = do rs <- asks envRanks - unless (rs M.! t == 0) $ do + -- unless (fromMaybe (error $ prettyString t) (rs M.!? t) == 0) $ do + unless (fromMaybe 0 (rs M.!? t) == 0) $ do new_vars <- gets substNewVars maybe new_var (const $ pure ()) $ new_vars M.!? t where new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars - let info = old_tyvars M.! t + let info = fromJust $ old_tyvars M.!? t modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 7021afd3d3..8fa0bf765a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1078,9 +1078,15 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do vns <- gets termNameSource - let debug_zero_ranks = True + let use_python = True - case rankAnalysis debug_zero_ranks vns counter cts tyvars of + traceM $ + unlines + [ "## cts:", + unlines $ map prettyString cts + ] + + case rankAnalysis use_python vns counter cts tyvars of Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} From 3c2e319b1928f6acfb6162ef7f13d380c806a0bc Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 12 Feb 2024 12:16:47 -0800 Subject: [PATCH 049/296] Add PuLP stuff. --- shell.nix | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/shell.nix b/shell.nix index cf8424a2be..d5199b0c02 100644 --- a/shell.nix +++ b/shell.nix @@ -2,6 +2,20 @@ let sources = import ./nix/sources.nix; pkgs = import sources.nixpkgs {}; + pps = ps: with ps; [ + ( + buildPythonPackage rec { + pname = "PuLP"; + version = "2.7.0"; + src = fetchPypi { + inherit pname version; + sha256 = "sha256-5z7msy1jnJuM9LSt7TNLoVi+X4MTVE4Fb3lqzgoQrmM="; + }; + doCheck = false; + } + ) + ]; + python = pkgs.python3.withPackages pps; in pkgs.stdenv.mkDerivation { name = "futhark"; @@ -38,6 +52,10 @@ pkgs.stdenv.mkDerivation { python3Packages.sphinx python3Packages.sphinxcontrib-bibtex imagemagick # needed for literate tests + # remove (needed for PuLP) + python + cbc + glpk ] ++ lib.optionals (stdenv.isLinux) [ opencl-headers From 5ef513d89d8f879fae436ec37dde09b215560318 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 21:45:32 +0100 Subject: [PATCH 050/296] Improve Solution. --- .../Futhark/TypeChecker/Constraints.hs | 21 +++++++------------ src/Language/Futhark/TypeChecker/Terms.hs | 13 +++++++----- src/Language/Futhark/TypeChecker/Terms2.hs | 3 ++- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1886be827c..b504d3d663 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -117,33 +117,28 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt --- | A solution maps types to the set of type variables that must be --- substituted with this type, as well as its binding level. This --- slightly odd representation is needed to encode when two type +-- | A solution maps a type variable to its substitution, binding +-- level, and additional type variables that are linked to this type. +-- This slightly odd representation is needed to encode when two type -- variables are actually the same type. This matters when we start -- instanting the sizes of the type. -type Solution = M.Map Type (Int, [TyVar]) +type Solution = M.Map TyVar (Type, Int, [TyVar]) solution :: SolverState -> Solution solution s = - L.foldl' byType mempty $ + L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ M.toList $ - L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ - M.toList $ - solverTyVars s + solverTyVars s where - mkSubst (TyVarSol lvl t) = Just (lvl, (t, [])) + mkSubst (TyVarSol lvl t) = Just (t, lvl, []) mkSubst _ = Nothing addLinks m (v1, TyVarLink v2) = case M.lookup v2 $ solverTyVars s of Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) _ -> case M.lookup v2 m of Nothing -> m - Just (t, (lvl, vs)) -> M.insert v2 (t, (lvl, v1 : vs)) m + Just (t, lvl, vs) -> M.insert v2 (t, lvl, v1 : vs) m addLinks m _ = m - byType m (v, (lvl, (t, vs))) = M.insertWith comb t (lvl, v : vs) m - where - comb (lvl1, ts1) (lvl2, ts2) = (min lvl1 lvl2, ts1 <> ts2) newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 816f334069..65c82ae0fd 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1626,15 +1626,15 @@ arrayOfM loc t shape = do atLevel :: Int -> TermTypeM a -> TermTypeM a atLevel lvl = local $ \env -> env {termLevel = lvl} -addInitialConstraints :: M.Map (TypeBase () NoUniqueness) (Int, [VName]) -> TermTypeM () +addInitialConstraints :: Terms2.Solution -> TermTypeM () addInitialConstraints = mapM_ f . M.toList where addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) usage = mkUsage (mempty :: Loc) - f (t, (lvl, vs)) = do + f (v, (t, lvl, vs)) = do (t', _) <- atLevel lvl $ allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t - forM_ vs $ \v -> - addConstraint v lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v + forM_ (v : vs) $ \v' -> + addConstraint v' lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v' -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant @@ -1660,7 +1660,10 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right tysubsts -> runTermTypeM checkExp $ do - addInitialConstraints $ M.mapKeys (first $ const ()) tysubsts + addInitialConstraints tysubsts + + traceM $ unlines $ map prettyString params + traceM $ prettyString body' (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 8fa0bf765a..d6e95785a7 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -37,6 +37,7 @@ -- inference, perhaps we can do it in a post-inference check. module Language.Futhark.TypeChecker.Terms2 ( checkValDef, + Solution, ) where @@ -1099,7 +1100,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", - let p (t, (lvl, vs)) = unwords (show [lvl] : map prettyNameString vs) <> " => " <> prettyString t + let p (v, (t, lvl, vs)) = unwords (show [lvl] : map prettyNameString (v : vs)) <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') From 678633bf18b56008b684695e2ac989d2930f4689 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 12 Feb 2024 22:00:41 +0100 Subject: [PATCH 051/296] Preserve types better. --- src/Language/Futhark/TypeChecker/Terms.hs | 1 - src/Language/Futhark/TypeChecker/Terms/Pat.hs | 19 ++++++++----------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 65c82ae0fd..a5073ce194 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1662,7 +1662,6 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do Right tysubsts -> runTermTypeM checkExp $ do addInitialConstraints tysubsts - traceM $ unlines $ map prettyString params traceM $ prettyString body' (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 16ad00f710..3c80a3b5ca 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -115,9 +115,9 @@ bindingIdent ident = binding [ident] checkPat' :: [(SizeBinder VName, QualName VName)] -> - Pat ParamType -> + Pat (TypeBase Size u) -> Inferred ParamType -> - TermTypeM (Pat ParamType) + TermTypeM (Pat (TypeBase Size u)) checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = @@ -151,11 +151,6 @@ checkPat' sizes p@(RecordPat p_fs loc) (Ascribed t) RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do p_fs' <- traverse (const $ newTypeVar loc "t") $ M.fromList p_fs - - when (sort (M.keys p_fs') /= sort (map fst p_fs)) $ - typeError loc mempty $ - "Duplicate fields in record pattern" <+> pretty p <> "." - unify (mkUsage loc "matching a record pattern") (Scalar (Record p_fs')) (toStruct t) checkPat' sizes p $ Ascribed $ toParam Observe $ Scalar (Record p_fs') where @@ -193,12 +188,12 @@ checkPat :: [(SizeBinder VName, QualName VName)] -> Pat (TypeBase Size u) -> Inferred StructType -> - (Pat ParamType -> TermTypeM a) -> + (Pat (TypeBase Size u) -> TermTypeM a) -> TermTypeM a checkPat sizes p t m = do p' <- onFailure (CheckingPat (fmap toStruct p) t) $ - checkPat' sizes (fmap (toParam Observe) p) (fmap (toParam Observe) t) + checkPat' sizes p (fmap (toParam Observe) t) let explicit = mustBeExplicitInType $ patternStructType p' @@ -216,7 +211,7 @@ bindingPat :: [SizeBinder VName] -> Pat (TypeBase Size u) -> StructType -> - (Pat ParamType -> TermTypeM a) -> + (Pat (TypeBase Size u) -> TermTypeM a) -> TermTypeM a bindingPat sizes p t m = do substs <- mapM mkSizeSubst sizes @@ -240,7 +235,9 @@ bindingParams :: bindingParams tps orig_ps m = bindingTypeParams tps $ do let descend ps' (p : ps) = checkPat [] p NoneInferred $ \p' -> - binding (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps + binding (patIdents $ fmap toStruct p') $ + incLevel $ + descend (p' : ps') ps descend ps' [] = m $ reverse ps' incLevel $ descend [] orig_ps From c73e681ed03be44ee566f87750b2e4b45ae351db Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 14 Feb 2024 17:33:58 +0100 Subject: [PATCH 052/296] WIP on type checker integration. --- src/Language/Futhark/TypeChecker/Terms.hs | 26 +-- .../Futhark/TypeChecker/Terms/Monad.hs | 150 ++++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 3 files changed, 91 insertions(+), 87 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index a5073ce194..abf797470c 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1009,7 +1009,7 @@ checkApply loc (fname, prev_applied) ftype argexp = do -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM checkExp $ do +checkOneExp e = runTermTypeM checkExp mempty $ do e' <- checkExp $ undefined e let t = typeOf e' (tparams, _, _) <- @@ -1023,7 +1023,7 @@ checkOneExp e = runTermTypeM checkExp $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM checkExp $ do +checkSizeExp e = runTermTypeM checkExp mempty $ do e' <- checkExp $ undefined e let t = typeOf e' when (hasBinding e') $ @@ -1621,21 +1621,6 @@ arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t --- A hack we need to create size variables for types at the right --- level. -atLevel :: Int -> TermTypeM a -> TermTypeM a -atLevel lvl = local $ \env -> env {termLevel = lvl} - -addInitialConstraints :: Terms2.Solution -> TermTypeM () -addInitialConstraints = mapM_ f . M.toList - where - addConstraint v lvl c = modifyConstraints $ M.insert v (lvl, c) - usage = mkUsage (mempty :: Loc) - f (v, (t, lvl, vs)) = do - (t', _) <- atLevel lvl $ allDimsFreshInType (usage (prettyText t)) Nonrigid "dv" t - forM_ (v : vs) $ \v' -> - addConstraint v' lvl $ Constraint (RetType [] t') $ usage $ prettyNameText v' - -- | Type-check a top-level (or module-level) function definition. -- Despite the name, this is also used for checking constant -- definitions, by treating them as 0-ary functions. @@ -1657,11 +1642,12 @@ checkFunDef :: checkFunDef (fname, retdecl, tparams, params, body, loc) = do (maybe_tysubsts, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + let adjust = M.fromList . concatMap f . M.toList + where + f (v, (t, _, vs)) = map (,first (const ()) t) (v : vs) case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp $ do - addInitialConstraints tysubsts - + Right tysubsts -> runTermTypeM checkExp (adjust tysubsts) $ do traceM $ prettyString body' (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index f5ca46c85c..b40e377530 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -50,17 +50,20 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State.Strict +import Data.Bifunctor import Data.Bitraversable -import Data.Char (isAscii) +import Data.Foldable import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T +import Debug.Trace import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals +import Language.Futhark.TypeChecker.Constraints (TyVar) import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod, stateNameSource) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types @@ -198,6 +201,7 @@ data TermEnv = TermEnv termLevel :: Level, termCheckExp :: ExpBase Info VName -> TermTypeM Exp, termOuterEnv :: Env, + termTyVars :: M.Map TyVar (TypeBase () NoUniqueness), termImportName :: ImportName } @@ -347,43 +351,59 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." --- | Create a new type name and insert it (unconstrained) in the --- substitution map. -instantiateTypeParam :: - (Monoid as) => - QualName VName -> - SrcLoc -> - TypeParam -> - TermTypeM (VName, Subst (RetTypeBase dim as)) -instantiateTypeParam qn loc tparam = do - i <- incCounter - let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) - v <- newID $ mkTypeVarName name i - case tparam of - TypeParamType x _ _ -> do - constrain v . NoConstraint x . mkUsage loc . docText $ - "instantiated type parameter of " <> dquotes (pretty qn) - pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) - TypeParamDim {} -> do - constrain v . Size Nothing . mkUsage loc . docText $ - "instantiated size parameter of " <> dquotes (pretty qn) - pure (v, ExpSubst $ sizeFromName (qualName v) loc) - --- | Instantiate a type scheme with fresh type variables for its type --- parameters. Returns the names of the fresh type variables, the --- instance list, and the instantiated type. -instantiateTypeScheme :: +replaceTyVars :: SrcLoc -> TypeBase () NoUniqueness -> StructType -> TermTypeM StructType +replaceTyVars loc orig_t1 orig_t2 = do + tyvars <- asks termTyVars + let f :: (Monoid u) => TypeBase () u' -> TypeBase Size u -> TermTypeM (TypeBase Size u) + f + (Scalar (TypeVar _ (QualName [] v1) [])) + t2 + | Just t <- M.lookup v1 tyvars = + f t t2 + | otherwise = + pure $ Scalar (TypeVar (fold t2) (QualName [] v1) []) + f (Scalar (Record fs1)) (Scalar (Record fs2)) = + Scalar . Record <$> sequence (M.intersectionWith f fs1 fs2) + f (Scalar (Sum fs1)) (Scalar (Sum fs2)) = + Scalar . Sum <$> sequence (M.intersectionWith (zipWithM f) fs1 fs2) + f + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow u pname d t2a (RetType ext t2r))) = do + ta <- f t1a t2a + tr <- f t1r t2r + pure $ Scalar $ Arrow u pname d ta $ RetType ext tr + f + (Array _ (Shape (() : ds1)) t1) + (Array u (Shape (d : ds2)) t2) = + arrayOfWithAliases u (Shape [d]) + <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) + f _ t2 = pure t2 + f orig_t1 orig_t2 + +-- | Instantiate a type scheme with fresh size variables for its size +-- parameters. Replaces type parameters with their known +-- instantiations. Returns the names of the fresh size variables and +-- the instantiated type. +instTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> StructType -> + TypeBase () NoUniqueness -> TermTypeM ([VName], StructType) -instantiateTypeScheme qn loc tparams t = do - let tnames = map typeParamName tparams - (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams - let substs = M.fromList $ zip tnames tparam_substs - t' = applySubst (`M.lookup` substs) t - pure (tparam_names, t') +instTypeScheme qn loc tparams scheme_t inferred = do + (names, substs) <- fmap (unzip . catMaybes) $ + forM tparams $ \tparam -> do + case tparam of + TypeParamType {} -> pure Nothing + TypeParamDim v _ -> do + constrain v . Size Nothing . mkUsage loc . docText $ + "instantiated size parameter of " <> dquotes (pretty qn) + pure $ Just (v, (v, ExpSubst $ sizeFromName (qualName v) loc)) + + t' <- replaceTyVars loc inferred $ applySubst (`lookup` substs) scheme_t + + pure (names, t') lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) @@ -453,36 +473,33 @@ lookupVar loc qn@(QualName qs name) t = do scope <- lookupQualNameEnv qn let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) - t' <- - case M.lookup name $ scopeVtable scope of - Nothing -> - error $ "lookupVar: " <> show qn - Just (BoundV tparams bound_t) -> do - when (null qs) . modify $ \s -> - s {stateUsed = S.insert name $ stateUsed s} - if null tparams && null qs - then pure t - else do - (tnames, t') <- instantiateTypeScheme qn loc tparams bound_t - outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' - Just EqualityF -> do - argtype <- newTypeVar loc "t" - equalityType usage argtype - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeVar loc "t" - mustBeOneOf ts usage argtype - let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' - unify (mkUsage loc "inferred rank type") t' t - pure t' + case M.lookup name $ scopeVtable scope of + Nothing -> + error $ "lookupVar: " <> show qn + Just (BoundV tparams bound_t) -> do + when (null qs) . modify $ \s -> + s {stateUsed = S.insert name $ stateUsed s} + if null tparams && null qs + then pure t + else do + (tnames, t') <- instTypeScheme qn loc tparams bound_t $ first (const ()) t + outer_env <- asks termOuterEnv + pure $ qualifyTypeVars outer_env tnames qs t' + Just EqualityF -> do + argtype <- newTypeVar loc "t" + equalityType usage argtype + pure $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ + Scalar $ + Arrow mempty Unnamed Observe argtype $ + RetType [] $ + Scalar $ + Prim Bool + Just (OverloadedF ts pts rt) -> do + argtype <- newTypeVar loc "t" + mustBeOneOf ts usage argtype + let (pts', rt') = instOverloaded argtype pts rt + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, @@ -639,8 +656,8 @@ initialTermScope = Just (name, EqualityF) addIntrinsicF _ = Nothing -runTermTypeM :: (ExpBase Info VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a -runTermTypeM checker (TermTypeM m) = do +runTermTypeM :: (ExpBase Info VName -> TermTypeM Exp) -> M.Map TyVar (TypeBase () NoUniqueness) -> TermTypeM a -> TypeM a +runTermTypeM checker tyvars (TermTypeM m) = do initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv name <- askImportName outer_env <- askEnv @@ -652,7 +669,8 @@ runTermTypeM checker (TermTypeM m) = do termLevel = 0, termCheckExp = checker, termImportName = name, - termOuterEnv = outer_env + termOuterEnv = outer_env, + termTyVars = tyvars } initial_state = TermTypeState diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d6e95785a7..804fcf9c48 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -362,7 +362,7 @@ instTypeScheme :: TermM ([VName], StructType) instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ - forM tparams $ \tparam -> do + forM tparams $ \tparam -> case tparam of TypeParamType _ v _ -> do v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v From fb42c3245cc122851ca279f739e57584138dfbb3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 Feb 2024 11:24:16 +0100 Subject: [PATCH 053/296] Now we get quite far in type checking! --- .../Futhark/TypeChecker/Constraints.hs | 28 +++----- src/Language/Futhark/TypeChecker/Rank.hs | 1 + src/Language/Futhark/TypeChecker/Terms.hs | 72 +++++++++---------- .../Futhark/TypeChecker/Terms/Monad.hs | 35 +++++---- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 12 +++- src/Language/Futhark/TypeChecker/Terms2.hs | 15 +++- 6 files changed, 89 insertions(+), 74 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b504d3d663..b6920c0f5b 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -106,7 +106,7 @@ substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = case M.lookup v m of Just (TyVarLink v') -> substTyVars m $ Scalar $ TypeVar u (QualName qs v') args - Just (TyVarSol _ t') -> second (const mempty) t' + Just (TyVarSol _ t') -> second (const mempty) $ substTyVars m t' Just (TyVarUnsol {}) -> t Nothing -> t substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt @@ -117,28 +117,20 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt --- | A solution maps a type variable to its substitution, binding --- level, and additional type variables that are linked to this type. --- This slightly odd representation is needed to encode when two type --- variables are actually the same type. This matters when we start --- instanting the sizes of the type. -type Solution = M.Map TyVar (Type, Int, [TyVar]) +-- | A solution maps a type variable to its substitution. This substitution is complete, in the sense there are no right-hand sides that contain a type variable. +type Solution = M.Map TyVar (TypeBase () NoUniqueness) solution :: SolverState -> Solution solution s = - L.foldl' addLinks (M.mapMaybe mkSubst $ solverTyVars s) $ - M.toList $ - solverTyVars s + M.mapMaybe mkSubst $ + solverTyVars s where - mkSubst (TyVarSol lvl t) = Just (t, lvl, []) + mkSubst (TyVarSol _lvl t) = Just $ first (const ()) $ substTyVars (solverTyVars s) t + mkSubst (TyVarLink v') = mkSubst =<< M.lookup v' (solverTyVars s) + mkSubst (TyVarUnsol _ (TyVarPrim pts)) + | Signed Int32 `elem` pts = + Just (Scalar (Prim (Signed Int32))) -- XXX - we need warnings and things! mkSubst _ = Nothing - addLinks m (v1, TyVarLink v2) = - case M.lookup v2 $ solverTyVars s of - Just (TyVarLink v3) -> addLinks m (v1, TyVarLink v3) - _ -> case M.lookup v2 m of - Nothing -> m - Just (t, lvl, vs) -> M.insert v2 (t, lvl, v1 : vs) m - addLinks m _ = m newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 83d7d5dce8..56d1be4e56 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -138,6 +138,7 @@ mkLinearProg counter cs tyVars = finalState = flip execState initState $ runRankM buildLP rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) +rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_python vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index abf797470c..01b8cfdc3d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -352,14 +352,17 @@ unscopeType tloc unscoped = checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc -checkExp (Hole info loc) = - pure $ Hole info loc +checkExp (Hole (Info t) loc) = do + t' <- replaceTyVars loc t t + pure $ Hole (Info t') loc checkExp (StringLit vs loc) = pure $ StringLit vs loc -checkExp (IntLit val info loc) = - pure $ IntLit val info loc -checkExp (FloatLit val info loc) = - pure $ FloatLit val info loc +checkExp (IntLit val (Info t) loc) = do + t' <- replaceTyVars loc t t + pure $ IntLit val (Info t') loc +checkExp (FloatLit val (Info t) loc) = do + t' <- replaceTyVars loc t t + pure $ FloatLit val (Info t') loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc checkExp (RecordLit fs loc) = @@ -662,14 +665,14 @@ checkExp (Assert e1 e2 _ loc) = do e1' <- require "being asserted" [Bool] =<< checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc -checkExp (Lambda params body rettype_te _ loc) = do +checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do (params', body', rettype', RetType dims ty) <- incLevel . bindingParams [] params $ \params' -> do rettype_checked <- traverse checkTypeExpNonrigid rettype_te - let declared_rettype = - case rettype_checked of - Just (_, st, _) -> Just st - Nothing -> Nothing + declared_rettype <- + case rettype_checked of + Just (_, st, _) -> Just <$> replaceTyVars loc rt st + Nothing -> pure Nothing body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' @@ -677,11 +680,13 @@ checkExp (Lambda params body rettype_te _ loc) = do (rettype', rettype_st) <- case rettype_checked of - Just (te, st, ext) -> - pure (Just te, RetType ext st) + Just (te, st, ext) -> do + st' <- replaceTyVars loc rt st + pure (Just te, RetType ext st') Nothing -> do - ret <- inferReturnSizes params'' $ toRes Nonunique body_t - pure (Nothing, ret) + RetType ext ret <- inferReturnSizes params'' $ toRes Nonunique body_t + ret' <- replaceTyVars loc rt ret + pure (Nothing, RetType ext ret') pure (params'', body', rettype', rettype_st) @@ -851,14 +856,10 @@ instance Pretty (Unmatched (Pat StructType)) where checkSlice :: SliceBase Info VName -> TermTypeM [DimIndex] checkSlice = mapM checkDimIndex where - checkDimIndex (DimFix i) = do - DimFix <$> (require "use as index" anySignedType =<< checkExp i) + checkDimIndex (DimFix i) = + DimFix <$> checkExp i checkDimIndex (DimSlice i j s) = - DimSlice <$> check i <*> check j <*> check s - - check = - maybe (pure Nothing) $ - fmap Just . unifies "use as index" (Scalar $ Prim $ Signed Int64) <=< checkExp + DimSlice <$> traverse checkExp i <*> traverse checkExp j <*> traverse checkExp s -- The number of dimensions affected by this slice (so the minimum -- rank of the array we are slicing). @@ -1023,14 +1024,18 @@ checkOneExp e = runTermTypeM checkExp mempty $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp -checkSizeExp e = runTermTypeM checkExp mempty $ do - e' <- checkExp $ undefined e - let t = typeOf e' - when (hasBinding e') $ - typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $ - "Size expression with binding is forbidden." - unify (mkUsage e' "Size expression") t (Scalar (Prim (Signed Int64))) - updateTypes e' +checkSizeExp e = do + (maybe_tysubsts, e') <- Terms2.checkSingleExp e + case maybe_tysubsts of + Left err -> typeError e' mempty $ pretty err + Right tysubsts -> runTermTypeM checkExp tysubsts $ do + e'' <- checkExp e' + let t = typeOf e'' + when (hasBinding e'') $ + typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ + "Size expression with binding is forbidden." + unify (mkUsage e'' "Size expression") t (Scalar (Prim (Signed Int64))) + updateTypes e'' -- Verify that all sum type constructors and empty array literals have -- a size that is known (rigid or a type parameter). This is to @@ -1642,14 +1647,9 @@ checkFunDef :: checkFunDef (fname, retdecl, tparams, params, body, loc) = do (maybe_tysubsts, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - let adjust = M.fromList . concatMap f . M.toList - where - f (v, (t, _, vs)) = map (,first (const ()) t) (v : vs) case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp (adjust tysubsts) $ do - traceM $ prettyString body' - + Right tysubsts -> runTermTypeM checkExp tysubsts $ do (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index b40e377530..2b7e2cf9fd 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -25,6 +25,7 @@ module Language.Futhark.TypeChecker.Terms.Monad constrain, newArrayType, allDimsFreshInType, + replaceTyVars, updateTypes, Names, @@ -351,33 +352,37 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." -replaceTyVars :: SrcLoc -> TypeBase () NoUniqueness -> StructType -> TermTypeM StructType +replaceTyVars :: + SrcLoc -> + TypeBase d u1 -> + TypeBase Size u2 -> + TermTypeM (TypeBase Size u1) replaceTyVars loc orig_t1 orig_t2 = do tyvars <- asks termTyVars - let f :: (Monoid u) => TypeBase () u' -> TypeBase Size u -> TermTypeM (TypeBase Size u) + let f :: TypeBase d u1 -> TypeBase Size u2 -> TermTypeM (TypeBase Size u1) f - (Scalar (TypeVar _ (QualName [] v1) [])) + (Scalar (TypeVar u (QualName [] v1) [])) t2 | Just t <- M.lookup v1 tyvars = - f t t2 - | otherwise = - pure $ Scalar (TypeVar (fold t2) (QualName [] v1) []) + f (second (const u) t) t2 f (Scalar (Record fs1)) (Scalar (Record fs2)) = Scalar . Record <$> sequence (M.intersectionWith f fs1 fs2) f (Scalar (Sum fs1)) (Scalar (Sum fs2)) = Scalar . Sum <$> sequence (M.intersectionWith (zipWithM f) fs1 fs2) f - (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) - (Scalar (Arrow u pname d t2a (RetType ext t2r))) = do + (Scalar (Arrow u _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ pname d t2a (RetType ext t2r))) = do ta <- f t1a t2a tr <- f t1r t2r pure $ Scalar $ Arrow u pname d ta $ RetType ext tr f - (Array _ (Shape (() : ds1)) t1) - (Array u (Shape (d : ds2)) t2) = + (Array u (Shape (_ : ds1)) t1) + (Array _ (Shape (d : ds2)) t2) = arrayOfWithAliases u (Shape [d]) <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) - f _ t2 = pure t2 + f t1 _ = + fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1 + f orig_t1 orig_t2 -- | Instantiate a type scheme with fresh size variables for its size @@ -469,7 +474,7 @@ instance MonadTypeChecker TermTypeM where throwError $ TypeError (locOf loc) notes s lookupVar :: SrcLoc -> QualName VName -> StructType -> TermTypeM StructType -lookupVar loc qn@(QualName qs name) t = do +lookupVar loc qn@(QualName qs name) inst_t = do scope <- lookupQualNameEnv qn let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) @@ -480,11 +485,11 @@ lookupVar loc qn@(QualName qs name) t = do when (null qs) . modify $ \s -> s {stateUsed = S.insert name $ stateUsed s} if null tparams && null qs - then pure t + then pure bound_t else do - (tnames, t') <- instTypeScheme qn loc tparams bound_t $ first (const ()) t + (tnames, t) <- instTypeScheme qn loc tparams bound_t $ first (const ()) inst_t outer_env <- asks termOuterEnv - pure $ qualifyTypeVars outer_env tnames qs t' + pure $ qualifyTypeVars outer_env tnames qs t Just EqualityF -> do argtype <- newTypeVar loc "t" equalityType usage argtype diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 3c80a3b5ca..f8bd0a42f6 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -122,9 +122,17 @@ checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc -checkPat' _ (Id name (Info t) loc) _ = +checkPat' _ (Id name (Info t) loc) NoneInferred = do + t' <- replaceTyVars loc (first (const ()) t) t + pure $ Id name (Info t') loc +checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do + t <- replaceTyVars loc (first (const ()) t1) t2 pure $ Id name (Info t) loc -checkPat' _ (Wildcard (Info t) loc) _ = +checkPat' _ (Wildcard (Info t) loc) NoneInferred = do + t' <- replaceTyVars loc (first (const ()) t) t + pure $ Wildcard (Info t') loc +checkPat' _ (Wildcard (Info t1) loc) (Ascribed t2) = do + t <- replaceTyVars loc (first (const ()) t1) t2 pure $ Wildcard (Info t) loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 804fcf9c48..e4ed9daef9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -37,6 +37,7 @@ -- inference, perhaps we can do it in a post-inference check. module Language.Futhark.TypeChecker.Terms2 ( checkValDef, + checkSingleExp, Solution, ) where @@ -1063,7 +1064,7 @@ checkValDef :: Maybe (TypeExp Exp VName), Exp ) -checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do +checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bindParams tparams params $ \params' -> do body' <- checkExp body @@ -1075,7 +1076,7 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "\n# function " <> prettyNameString fname <> "\n" + traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" vns <- gets termNameSource @@ -1100,7 +1101,15 @@ checkValDef (fname, retdecl, tparams, params, body, _loc) = runTermM $ do "## tyvars:", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", - let p (v, (t, lvl, vs)) = unwords (show [lvl] : map prettyNameString (v : vs)) <> " => " <> prettyString t + let p (v, t) = prettyNameString v <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] pure (solution, params', retdecl', body') + +checkSingleExp :: ExpBase NoInfo VName -> TypeM (Either T.Text Solution, Exp) +checkSingleExp e = runTermM $ do + e' <- checkExp e + cts <- gets termConstraints + tyvars <- gets termTyVars + let solution = solve cts tyvars + pure (solution, e') From 333ca8de3e2bbf7d8fc0ac39c33e19457ddcbf53 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 Feb 2024 15:00:17 +0100 Subject: [PATCH 054/296] Instantiate sizes properly. --- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 2b7e2cf9fd..0b1efd86b7 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -397,14 +397,15 @@ instTypeScheme :: TypeBase () NoUniqueness -> TermTypeM ([VName], StructType) instTypeScheme qn loc tparams scheme_t inferred = do - (names, substs) <- fmap (unzip . catMaybes) $ - forM tparams $ \tparam -> do - case tparam of - TypeParamType {} -> pure Nothing - TypeParamDim v _ -> do - constrain v . Size Nothing . mkUsage loc . docText $ - "instantiated size parameter of " <> dquotes (pretty qn) - pure $ Just (v, (v, ExpSubst $ sizeFromName (qualName v) loc)) + (names, substs) <- fmap (unzip . catMaybes) . forM tparams $ \tparam -> do + case tparam of + TypeParamType {} -> pure Nothing + TypeParamDim v _ -> do + i <- incCounter + v' <- newID $ mkTypeVarName (baseName v) i + constrain v' . Size Nothing . mkUsage loc . docText $ + "instantiated size parameter of " <> dquotes (pretty qn) + pure $ Just (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) t' <- replaceTyVars loc inferred $ applySubst (`lookup` substs) scheme_t From b4983c63cfccb7ab59ff57be39f6b8c922d58bb1 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 06:08:04 -0800 Subject: [PATCH 055/296] Don't normalize/forget variables. --- src/Futhark/Solve/LP.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index a2b625a5e0..c3e321dfae 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -239,7 +239,8 @@ cval :: (Num a, Ord v) => LSum v a -> a cval = (! Nothing) (~+~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a -(LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y +-- (LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y +(LSum x) ~+~ (LSum y) = LSum $ Map.unionWith (+) x y infixl 6 ~+~ @@ -249,7 +250,8 @@ x ~-~ y = x ~+~ (neg y) infixl 6 ~-~ (~*~) :: (Eq a, Num a, Ord v) => a -> LSum v a -> LSum v a -a ~*~ s = normalize $ fmap (a *) s +-- a ~*~ s = normalize $ fmap (a *) s +a ~*~ s = fmap (a *) s infixl 7 ~*~ From b122b274b0ea47ff376ee064413373ea0e5fee2a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 Feb 2024 15:11:24 +0100 Subject: [PATCH 056/296] Preserve uniqueness. --- src/Language/Futhark/TypeChecker/Constraints.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b6920c0f5b..665e67af3f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -113,7 +113,7 @@ substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs substTyVars m (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars m) cs substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = - Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 + Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 `setUniqueness` uniqueness t2 substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt From fc083add39863f5a3631e9a348f71432dfedb86f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 07:15:35 -0800 Subject: [PATCH 057/296] Jank frame fix. --- src/Language/Futhark/TypeChecker/Terms2.hs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e4ed9daef9..8005ffbcc8 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -596,7 +596,7 @@ bindParams tps orig_ps m = bindTypeParams tps $ do checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) checkApply loc _ ftype fframe arg = do - (a, b) <- split ftype + (a, b) <- split $ stripFrame fframe ftype r <- newSVar loc "R" m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool @@ -607,10 +607,17 @@ checkApply loc _ ftype fframe arg = do ctAM r m ctEq lhs rhs pure - ( b, + ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where + stripFrame :: Shape Size -> Type -> Type + stripFrame frame (Array u ds t) = + let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) + in case mnew_shape of + Nothing -> Scalar t + Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t + stripFrame _ t = t toSComp (Var (QualName [] x) _ _) = SVar x toSComp _ = error "" toShape = Shape . pure From 1ed744b04674eb5a70f06c52b7fb03a9774badb8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 09:50:04 -0800 Subject: [PATCH 058/296] SPEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEED --- futhark.cabal | 2 ++ src/Futhark/Solve/LP.hs | 19 ++++++++++++++ src/Language/Futhark/TypeChecker/Rank.hs | 30 +++------------------- src/Language/Futhark/TypeChecker/Terms2.hs | 4 +-- 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index c73b997f71..8cf87483f3 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -371,6 +371,7 @@ library Futhark.Pkg.Types Futhark.Profile Futhark.Script + Futhark.Solve.GLPK Futhark.Solve.LP Futhark.Solve.Matrix Futhark.Solve.Simplex @@ -495,6 +496,7 @@ library , prettyprinter-ansi-terminal >= 1.1 -- remove me later , process + , glpk-hs executable futhark import: common diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index c3e321dfae..7623033e7c 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -20,6 +20,8 @@ module Futhark.Solve.LP LinearProg (..), OptType (..), Constraint (..), + Vars (..), + CType (..), (~==~), (~<=~), (~>=~), @@ -28,11 +30,14 @@ module Futhark.Solve.LP ) where +import Control.Monad.LPMonad import Data.Char (isAscii) import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as Map import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V import Debug.Trace @@ -106,6 +111,12 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where instance Functor (LSum v) where fmap f (LSum m) = LSum $ fmap f m +class Vars a v where + vars :: a -> Set v + +instance (Ord v) => Vars (LSum v a) v where + vars = S.fromList . catMaybes . Map.keys . lsum + -- | Type of constraint data CType = Equal | LessEq deriving (Show, Eq) @@ -123,6 +134,9 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (Constraint v a) where pretty (Constraint t l r) = pretty l <+> pretty t <+> pretty r +instance (Ord v) => Vars (Constraint v a) v where + vars (Constraint _ l r) = vars l <> vars r + data OptType = Maximize | Minimize deriving (Show, Eq) @@ -147,6 +161,11 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where indent 2 $ vcat $ map pretty cs ] +instance (Ord v) => Vars (LinearProg v a) v where + vars lp = + vars (objective lp) + <> foldMap vars (constraints lp) + -- For debugging linearProgToPulp :: (Unbox a, IsName v, Ord v, Pretty a, Eq a, Num a) => LinearProg v a -> String linearProgToPulp prog = diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 56d1be4e56..d255f3106f 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -11,6 +11,7 @@ import Debug.Trace import Futhark.FreshNames qualified as FreshNames import Futhark.MonadFreshNames hiding (newName) import Futhark.Solve.BranchAndBound +import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Futhark.Solve.Simplex @@ -139,16 +140,11 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) -rankAnalysis use_python vns counter cs tyVars = do +rankAnalysis use_glpk vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- - if use_python - then do - -- traceM $ linearProgToPulp prog - parseRes $ - unsafePerformIO $ - readProcess "python" [] $ - linearProgToPulp prog + if use_glpk + then snd <$> (unsafePerformIO $ glpk prog) else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map @@ -195,24 +191,6 @@ rankAnalysis use_python vns counter cs tyVars = do pulp_var_to_vname = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList vname_to_pulp_var] - parseRes :: String -> Maybe (Map VName Int) - parseRes s = do - (status : vars) <- trimToStart $ lines s - if not (success status) - then Nothing - else do - pure $ M.fromList $ catMaybes $ map readVar vars - where - trimToStart [] = Nothing - trimToStart (l : ls) - | "status" `L.isPrefixOf` l = Just (l : ls) - | otherwise = trimToStart ls - success l = - (read $ drop (length ("status: " :: [Char])) l) == (1 :: Int) - readVar xs = - let (v, _ : value) = L.span (/= ':') xs - in Just (fromJust $ pulp_var_to_vname M.!? v, read value) - newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 8005ffbcc8..6fefb80673 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1087,7 +1087,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do vns <- gets termNameSource - let use_python = True + let use_glpk = True traceM $ unlines @@ -1095,7 +1095,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map prettyString cts ] - case rankAnalysis use_python vns counter cts tyvars of + case rankAnalysis use_glpk vns counter cts tyvars of Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} From 6eb745ef7ab1cd998b9f810b43b6d649ac6be086 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 15 Feb 2024 11:47:50 -0800 Subject: [PATCH 059/296] Add `GLPK.hs`. --- src/Futhark/Solve/GLPK.hs | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/Futhark/Solve/GLPK.hs diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs new file mode 100644 index 0000000000..7b27408a27 --- /dev/null +++ b/src/Futhark/Solve/GLPK.hs @@ -0,0 +1,47 @@ +module Futhark.Solve.GLPK (glpk) where + +import Data.LinearProgram +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.Solve.LP qualified as F + +linearProgToGLPK :: (Show v, Ord v, Eq a, Num a, Group a) => F.LinearProg v a -> (LP v a) +linearProgToGLPK prog = + LP + { direction = cOptType $ F.optType prog, + objective = cObj $ F.objective prog, + constraints = map cConstraint $ F.constraints prog, + varBounds = bounds, + varTypes = kinds + } + where + cOptType F.Maximize = Max + cOptType F.Minimize = Min + cObj = fst . cLSum + + cLSum (F.LSum m) = + ( M.mapKeys fromJust $ M.filterWithKey (\k _ -> isJust k) m, + fromMaybe 0 (m M.!? Nothing) + ) + + cConstraint (F.Constraint ctype l r) = + let (linfunc, c) = cLSum $ l F.~-~ r + bound = + case ctype of + F.Equal -> Equ (-c) + F.LessEq -> UBound (-c) + in Constr Nothing linfunc bound + + bounds = M.fromList $ (,LBound 0) <$> varList + kinds = M.fromList $ (,IntVar) <$> varList + + varList = S.toList $ F.vars prog + +glpk :: + (Show v, Ord v, Show a, Eq a, Real a, Group a) => + F.LinearProg v a -> + IO (Maybe (Int, M.Map v Int)) +glpk lp = do + (_, mres) <- glpSolveVars mipDefaults $ linearProgToGLPK lp + pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres From 1d772818bf52914d46ab70006410f1c43a5ab957 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 16 Feb 2024 15:21:10 +0100 Subject: [PATCH 060/296] Basic things work now. --- src/Language/Futhark/TypeChecker/Rank.hs | 5 +- src/Language/Futhark/TypeChecker/Terms.hs | 19 ++--- .../Futhark/TypeChecker/Terms/Monad.hs | 70 +++++++++++++++---- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 10 +-- 4 files changed, 78 insertions(+), 26 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index d255f3106f..23d295f8ee 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -115,7 +115,10 @@ addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () addTyVarInfo tv (_, TyVarFree) = pure () addTyVarInfo tv (_, TyVarPrim _) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo _ _ = error "Unhandled" +addTyVarInfo tv (_, TyVarRecord _) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarSum _) = + addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: Int -> [Ct] -> TyVars -> LinearProg mkLinearProg counter cs tyVars = diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 01b8cfdc3d..919ebda406 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -353,15 +353,15 @@ checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc checkExp (Hole (Info t) loc) = do - t' <- replaceTyVars loc t t + t' <- replaceTyVars loc t pure $ Hole (Info t') loc checkExp (StringLit vs loc) = pure $ StringLit vs loc checkExp (IntLit val (Info t) loc) = do - t' <- replaceTyVars loc t t + t' <- replaceTyVars loc t pure $ IntLit val (Info t') loc checkExp (FloatLit val (Info t) loc) = do - t' <- replaceTyVars loc t t + t' <- replaceTyVars loc t pure $ FloatLit val (Info t') loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc @@ -668,10 +668,13 @@ checkExp (Assert e1 e2 _ loc) = do checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do (params', body', rettype', RetType dims ty) <- incLevel . bindingParams [] params $ \params' -> do + rt' <- replaceTyVars loc rt rettype_checked <- traverse checkTypeExpNonrigid rettype_te declared_rettype <- case rettype_checked of - Just (_, st, _) -> Just <$> replaceTyVars loc rt st + Just (_, st, _) -> do + unify (mkUsage body "lambda return type ascription") (toStruct rt') (toStruct st) + pure $ Just st Nothing -> pure Nothing body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' @@ -680,13 +683,11 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do (rettype', rettype_st) <- case rettype_checked of - Just (te, st, ext) -> do - st' <- replaceTyVars loc rt st - pure (Just te, RetType ext st') + Just (te, _, ext) -> + pure (Just te, RetType ext rt') Nothing -> do RetType ext ret <- inferReturnSizes params'' $ toRes Nonunique body_t - ret' <- replaceTyVars loc rt ret - pure (Nothing, RetType ext ret') + pure (Nothing, RetType ext ret) pure (params'', body', rettype', rettype_st) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 0b1efd86b7..87697bc359 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -58,7 +58,6 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified import Futhark.Util.Pretty hiding (space) @@ -352,14 +351,47 @@ instance MonadUnify TermTypeM where indent 2 (pretty t2) "do not match." -replaceTyVars :: +replaceTyVars :: SrcLoc -> TypeBase Size u -> TermTypeM (TypeBase Size u) +replaceTyVars loc orig_t = do + tyvars <- asks termTyVars + let f :: TypeBase Size u -> TermTypeM (TypeBase Size u) + f (Scalar (Prim t)) = pure $ Scalar $ Prim t + f + (Scalar (TypeVar u (QualName [] v) [])) + | Just t <- M.lookup v tyvars = + fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" (second (const u) t) + | otherwise = + pure $ Scalar (TypeVar u (QualName [] v) []) + f (Scalar (TypeVar u qn targs)) = + Scalar . TypeVar u qn <$> mapM onTyArg targs + where + onTyArg (TypeArgDim e) = pure $ TypeArgDim e + onTyArg (TypeArgType t) = TypeArgType <$> f t + f (Scalar (Record fs)) = + Scalar . Record <$> traverse f fs + f (Scalar (Sum fs)) = + Scalar . Sum <$> traverse (mapM f) fs + f (Scalar (Arrow u pname d ta (RetType ext tr))) = do + ta' <- f ta + tr' <- f tr + pure $ Scalar $ Arrow u pname d ta' $ RetType ext tr' + f (Array u shape t) = + arrayOfWithAliases u shape <$> f (Scalar t) + + f orig_t + +instTyVars :: SrcLoc -> - TypeBase d u1 -> - TypeBase Size u2 -> - TermTypeM (TypeBase Size u1) -replaceTyVars loc orig_t1 orig_t2 = do + [VName] -> + TypeBase () u -> + TypeBase Size u -> + TermTypeM (TypeBase Size u) +instTyVars loc names orig_t1 orig_t2 = do tyvars <- asks termTyVars - let f :: TypeBase d u1 -> TypeBase Size u2 -> TermTypeM (TypeBase Size u1) + let f :: + TypeBase d u -> + TypeBase Size u -> + StateT (M.Map VName (TypeBase Size NoUniqueness)) TermTypeM (TypeBase Size u) f (Scalar (TypeVar u (QualName [] v1) [])) t2 @@ -380,10 +412,23 @@ replaceTyVars loc orig_t1 orig_t2 = do (Array _ (Shape (d : ds2)) t2) = arrayOfWithAliases u (Shape [d]) <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) - f t1 _ = - fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1 - - f orig_t1 orig_t2 + f t1 t2 = do + let mkNew = + fst <$> lift (allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1) + case t2 of + Scalar (TypeVar u (QualName [] v2) []) + | v2 `elem` names -> do + seen <- get + case M.lookup v2 seen of + Nothing -> do + t <- mkNew + modify $ M.insert v2 $ second (const NoUniqueness) t + pure t + Just t -> + pure $ second (const u) t + _ -> mkNew + + evalStateT (f orig_t1 orig_t2) mempty -- | Instantiate a type scheme with fresh size variables for its size -- parameters. Replaces type parameters with their known @@ -407,7 +452,8 @@ instTypeScheme qn loc tparams scheme_t inferred = do "instantiated size parameter of " <> dquotes (pretty qn) pure $ Just (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) - t' <- replaceTyVars loc inferred $ applySubst (`lookup` substs) scheme_t + let tp_names = map typeParamName $ filter isTypeParam tparams + t' <- instTyVars loc tp_names inferred $ applySubst (`lookup` substs) scheme_t pure (names, t') diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index f8bd0a42f6..b1a2f59a8d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -123,16 +123,18 @@ checkPat' sizes (PatParens p loc) t = checkPat' sizes (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc checkPat' _ (Id name (Info t) loc) NoneInferred = do - t' <- replaceTyVars loc (first (const ()) t) t + t' <- replaceTyVars loc t pure $ Id name (Info t') loc checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc (first (const ()) t1) t2 + t <- replaceTyVars loc t1 + unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) pure $ Id name (Info t) loc checkPat' _ (Wildcard (Info t) loc) NoneInferred = do - t' <- replaceTyVars loc (first (const ()) t) t + t' <- replaceTyVars loc t pure $ Wildcard (Info t') loc checkPat' _ (Wildcard (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc (first (const ()) t1) t2 + t <- replaceTyVars loc t1 + unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) pure $ Wildcard (Info t) loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, From cb21531f32edbe2226e760f3c9eca1ff1b97baf3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 16 Feb 2024 16:46:18 +0100 Subject: [PATCH 061/296] Fix typo. --- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6fefb80673..f2f7f8f327 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -601,7 +601,7 @@ checkApply loc _ ftype fframe arg = do m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty - m_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r) <> (toSComp <$> frameOf arg)) $ toType $ typeOf arg rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a ctAM r m From 4d673d3eba2777cb3d8a2976b7819224e9470b8a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 16 Feb 2024 17:51:26 +0100 Subject: [PATCH 062/296] Working AUTOMAP (not really). --- prelude/zip.fut | 2 +- .../Futhark/TypeChecker/Consumption.hs | 34 +++++++++-------- src/Language/Futhark/TypeChecker/Terms.hs | 38 ++++++++++++++----- 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/prelude/zip.fut b/prelude/zip.fut index 1171820307..18361e545f 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -11,7 +11,7 @@ -- depended upon by soacs.fut. So we just define a quick-and-dirty -- internal one here that uses the intrinsic version. local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map f as + f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index 5c1198537f..5ebe2996df 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -486,9 +486,10 @@ consumeAsNeeded loc (Scalar (Record fs1)) (Scalar (Record fs2)) = consumeAsNeeded loc pt t = when (diet pt == Consume) $ consumeAliases loc $ aliases t -checkArg :: [(Exp, TypeAliases)] -> ParamType -> Exp -> CheckM (Exp, TypeAliases) -checkArg prev p_t e = do - ((e', e_als), e_cons) <- contain $ checkExp e +checkArg :: [(Exp, TypeAliases)] -> ParamType -> AutoMap -> Exp -> CheckM (Exp, TypeAliases) +checkArg prev p_t am e = do + ((e', e_als), e_cons) <- + contain $ if autoRep am == mempty then noAliases e else checkExp e consumed e_cons let e_t = typeOf e' when (e_cons /= mempty && not (orderZero e_t)) $ @@ -542,9 +543,11 @@ returnType appres (Scalar (Arrow _ v pd t1 (RetType dims t2))) Observe arg = returnType appres (Scalar (Sum cs)) d arg = Scalar $ Sum $ (fmap . fmap) (\et -> returnType appres et d arg) cs -applyArg :: TypeAliases -> TypeAliases -> TypeAliases -applyArg (Scalar (Arrow closure_als _ d _ (RetType _ rettype))) arg_als = - returnType closure_als rettype d arg_als +applyArg :: TypeAliases -> (AutoMap, TypeAliases) -> TypeAliases +applyArg (Scalar (Arrow closure_als _ d _ (RetType _ rettype))) (am, arg_als) = + if autoMap am /= mempty + then second (const mempty) rettype + else returnType closure_als rettype d arg_als applyArg t _ = error $ "applyArg: " <> show t boundFreeInExp :: Exp -> CheckM (M.Map VName TypeAliases) @@ -664,7 +667,7 @@ checkLoop loop_loc (param, arg, form, body) = do param' <- convergeLoopParam loop_loc param (M.keysSet body_cons) body_als let param_t = patternType param' - ((arg', arg_als), arg_cons) <- contain $ checkArg [] param_t arg + ((arg', arg_als), arg_cons) <- contain $ checkArg [] param_t mempty arg consumed arg_cons free_bound <- boundFreeInExp body @@ -685,7 +688,7 @@ checkLoop loop_loc (param, arg, form, body) = do `setAliases` S.singleton (AliasFree v) pure ( (param', arg', form', body'), - applyArg loopt arg_als `combineAliases` body_als + applyArg loopt (mempty, arg_als) `combineAliases` body_als ) checkFuncall :: @@ -693,7 +696,7 @@ checkFuncall :: SrcLoc -> Maybe (QualName VName) -> TypeAliases -> - f TypeAliases -> + f (AutoMap, TypeAliases) -> CheckM TypeAliases checkFuncall loc fname f_als arg_als = do v <- VName "internal_app_result" <$> incCounter @@ -707,15 +710,16 @@ checkExp :: Exp -> CheckM (Exp, TypeAliases) checkExp (AppExp (Apply f args loc) appres) = do (f', f_als) <- checkExp f (args', args_als) <- NE.unzip <$> checkArgs (toRes Nonunique f_als) args - res_als <- checkFuncall loc (fname f) f_als args_als + res_als <- + checkFuncall loc (fname f) f_als $ + NE.zip (fmap (snd . unInfo . fst) args') args_als pure (AppExp (Apply f' args' loc) appres, res_als) where - -- neUnzip3 xs = ((\(x, _, _) -> x) <$> xs, (\(_, y, _) -> y) <$> xs, (\(_, _, z) -> z) <$> xs) fname (Var v _ _) = Just v fname (AppExp (Apply e _ _) _) = fname e fname _ = Nothing checkArg' prev d (Info (p, am), e) = do - (e', e_als) <- checkArg prev (second (const d) (typeOf e)) e + (e', e_als) <- checkArg prev (second (const d) (typeOf e)) am e pure ((Info (p, am), e'), e_als) checkArgs (Scalar (Arrow _ _ d _ (RetType _ rt))) (x NE.:| args') = do @@ -807,9 +811,9 @@ checkExp (AppExp (LetFun fname (typarams, params, te, Info (RetType ext ret), fu checkExp (AppExp (BinOp (op, oploc) opt (x, xp) (y, yp) loc) appres) = do op_als <- observeVar (locOf oploc) (qualLeaf op) (unInfo opt) let at1 : at2 : _ = fst $ unfoldFunType op_als - (x', x_als) <- checkArg [] at1 x - (y', y_als) <- checkArg [(x', x_als)] at2 y - res_als <- checkFuncall loc (Just op) op_als [x_als, y_als] + (x', x_als) <- checkArg [] at1 mempty x + (y', y_als) <- checkArg [(x', x_als)] at2 mempty y + res_als <- checkFuncall loc (Just op) op_als [(mempty, x_als), (mempty, y_als)] pure ( AppExp (BinOp (op, oploc) opt (x', xp) (y', yp) loc) appres, res_als diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 919ebda406..6510037fb6 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -479,8 +479,8 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (_, rt, p1_ext, _) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext) <- checkApply loc (Just op, 1) rt e2' + (_, rt, p1_ext, _, _) <- checkApply loc (Just op, 0) ftype e1' + (_, rt', p2_ext, retext, _) <- checkApply loc (Just op, 1) rt e2' pure $ AppExp @@ -548,10 +548,10 @@ checkExp (AppExp (Apply fe args loc) _) = do pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts where onArg fname (i, all_exts, t) arg' = do - (_, rt, argext, exts) <- checkApply loc (fname, i) t arg' + (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' pure - ( (i + 1, all_exts <> exts, rt), - (Info (argext, mempty), arg') + ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), + (Info (argext, am), arg') ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e @@ -726,7 +726,7 @@ checkExp (OpSection op (Info op_t) loc) = do checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext) <- checkApply loc (Just op, 0) ftype e' + (t1, rt, argext, retext, _) <- checkApply loc (Just op, 0) ftype e' case (ftype, rt) of (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 rettype)) -> pure $ @@ -745,7 +745,7 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do - (t2', arrow', argext, _) <- + (t2', arrow', argext, _, _) <- checkApply loc (Just op, 1) @@ -923,16 +923,27 @@ dimUses = flip execState mempty . traverseDims f where fv = freeInExp e `freeWithout` bound +-- | Try to find out how many dimensions of the argument we are +-- mapping. Returns the shape mapped and the remaining type. +stripToMatch :: StructType -> StructType -> (Shape Size, StructType) +stripToMatch paramt argt | toStructural paramt == toStructural argt = (mempty, argt) +stripToMatch paramt (Array _ (Shape (d : ds)) argt) = + first (Shape [d] <>) $ stripToMatch paramt $ arrayOf (Shape ds) (Scalar argt) +stripToMatch _ argt = (mempty, argt) + checkApply :: SrcLoc -> ApplyOp -> StructType -> Exp -> - TermTypeM (StructType, StructType, Maybe VName, [VName]) + TermTypeM (StructType, StructType, Maybe VName, [VName], AutoMap) checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do - unify (mkUsage argexp "use as function argument") tp1 argtype + (am_map_shape, argtype_automap) <- + stripToMatch <$> normTypeFully tp1 <*> normTypeFully argtype + + unify (mkUsage argexp "use as function argument") tp1 argtype_automap -- Perform substitutions of instantiated variables in the types. (tp2', ext) <- instantiateDimsInReturnType loc fname =<< normTypeFully tp2 @@ -972,7 +983,14 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do in pure (Nothing, applySubst parsubst $ toStruct tp2') _ -> pure (Nothing, toStruct tp2') - pure (tp1, tp2'', argext, ext) + let am = + AutoMap + { autoMap = am_map_shape, + autoRep = mempty, + autoFrame = am_map_shape + } + + pure (tp1, tp2'', argext, ext, am) checkApply loc fname tfun@(Scalar TypeVar {}) arg = do tv <- newTypeVar loc "b" unify (mkUsage loc "use as function") tfun $ From 8fd5e188ca49f8a1ecad90fd3178d992a23eacff Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 16 Feb 2024 14:40:52 -0800 Subject: [PATCH 063/296] Support AUTOMAP on `BinOp`s. --- src/Futhark/Internalise/FullNormalise.hs | 4 ++-- src/Futhark/Internalise/Monomorphise.hs | 6 +++--- src/Language/Futhark/Interpreter.hs | 2 +- src/Language/Futhark/Prop.hs | 4 ++-- src/Language/Futhark/Syntax.hs | 4 ++-- src/Language/Futhark/TypeChecker/Terms.hs | 18 +++++++++--------- src/Language/Futhark/TypeChecker/Terms2.hs | 4 ++-- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index fea8000abd..91e16a9a53 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -298,7 +298,7 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do While e -> While <$> transformBody e body' <- transformBody body nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT -getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) loc) (Info resT)) = do +getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do expr' <- case (isOr, isAnd) of (True, _) -> do el' <- naming "or_lhs" $ getOrdering True el @@ -311,7 +311,7 @@ getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) lo (False, False) -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er - pure $ mkApply (Var op opT oloc) [(elp, mempty, el'), (erp, mempty, er')] resT + pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT nameExp final expr' where isOr = baseName (qualLeaf op) == "||" diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index e0d9834cfa..ee352f67c6 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -535,7 +535,7 @@ transformAppExp (Loop sparams pat e1 form body loc) res = do (pat_sizes, pat'') <- sizesForPat pat' res' <- transformAppRes res pure $ AppExp (Loop (sparams' ++ pat_sizes) pat'' e1' form' body' loc) (Info res') -transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do +transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, am1)) (e2, Info (d2, am2)) loc) res = do (AppRes ret ext) <- transformAppRes res fname' <- transformFName loc fname (toStruct t) e1' <- transformExp e1 @@ -570,8 +570,8 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do where applyOp ret ext fname' x y = mkApply - (mkApply fname' [(unInfo d1, mempty, x)] (AppRes ret mempty)) - [(unInfo d2, mempty, y)] + (mkApply fname' [(d1, am1, x)] (AppRes ret mempty)) + [(d2, am2, y)] (AppRes ret ext) makeVarParam arg = do diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index ac76cf6645..01e4bddb21 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -828,7 +828,7 @@ evalAppExp env (LetPat sizes p e body _) = do evalAppExp env (LetFun f (tparams, ps, _, Info ret, fbody) body _) = do binding <- evalFunctionBinding env tparams ps ret fbody eval (env {envTerm = M.insert f binding $ envTerm env}) body -evalAppExp env (BinOp (op, _) op_t (x, Info xext) (y, Info yext) loc) +evalAppExp env (BinOp (op, _) op_t (x, Info (xext, xam)) (y, Info (yext, yam)) loc) | baseString (qualLeaf op) == "&&" = do x' <- asBool <$> eval env x if x' diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index cc5c40268a..419e50ba6b 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -682,8 +682,8 @@ mkBinOp op t x y = ( BinOp (qualName (intrinsicVar op), mempty) (Info t) - (x, Info Nothing) - (y, Info Nothing) + (x, Info (Nothing, mempty)) + (y, Info (Nothing, mempty)) mempty ) (Info $ AppRes t []) diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index f0f2a586df..bd2133f017 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -744,8 +744,8 @@ data AppExpBase f vn | BinOp (QualName vn, SrcLoc) (f StructType) - (ExpBase f vn, f (Maybe VName)) - (ExpBase f vn, f (Maybe VName)) + (ExpBase f vn, f (Maybe VName, AutoMap)) + (ExpBase f vn, f (Maybe VName, AutoMap)) SrcLoc | LetWith (IdentBase f vn StructType) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 6510037fb6..967919c52d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -170,8 +170,8 @@ sliceShape r slice t@(Array u (Shape orig_dims) et) = ( BinOp (qualName (intrinsicVar "-"), mempty) sizeBinOpInfo - (j, Info Nothing) - (i, Info Nothing) + (j, Info (Nothing, mempty)) + (i, Info (Nothing, mempty)) mempty ) $ Info @@ -454,8 +454,8 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do ( BinOp (qualName (intrinsicVar op), mempty) sizeBinOpInfo - (x, Info Nothing) - (y, Info Nothing) + (x, Info (Nothing, mempty)) + (y, Info (Nothing, mempty)) mempty ) (Info $ AppRes t []) @@ -479,16 +479,16 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (_, rt, p1_ext, _, _) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext, _) <- checkApply loc (Just op, 1) rt e2' + (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' + (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) rt e2' pure $ AppExp ( BinOp (op, oploc) (Info ftype) - (e1', Info p1_ext) - (e2', Info p2_ext) + (e1', Info (p1_ext, am1)) + (e2', Info (p2_ext, am2)) loc ) (Info (AppRes rt' retext)) @@ -1143,7 +1143,7 @@ causalityCheck binding_body = do modify (new_known <>) onExp known - e@(AppExp (BinOp (f, floc) ft (x, Info xp) (y, Info yp) _) (Info res)) = do + e@(AppExp (BinOp (f, floc) ft (x, Info (xp, _)) (y, Info (yp, _)) _) (Info res)) = do args_known <- collectingNewKnown $ sequencePoint known x y $ catMaybes [xp, yp] void $ onExp (args_known <> known) (Var f ft floc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index f2f7f8f327..71318292ef 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -835,12 +835,12 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do e2' <- checkExp e2 (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) mempty e1' - (rt2, am2) <- checkApply loc (Just op, 1) rt1 mempty e2' + (rt2, am2) <- checkApply loc (Just op, 1) rt1 (autoFrame am1) e2' rt2' <- asStructType loc rt2 pure $ AppExp - (BinOp (op, oploc) (Info ftype) (e1', Info Nothing) (e2', Info Nothing) loc) + (BinOp (op, oploc) (Info ftype) (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) (Info (AppRes rt2' [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do From 32c9a73f7d62b277e8f21a4b88f1f6447051f92e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 14:07:33 +0100 Subject: [PATCH 064/296] Some work on overloaded type variables. --- .../Futhark/TypeChecker/Constraints.hs | 8 ++-- src/Language/Futhark/TypeChecker/Terms.hs | 1 + src/Language/Futhark/TypeChecker/Terms2.hs | 45 ++++++++++++++++--- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 665e67af3f..2548c74008 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -118,18 +118,16 @@ substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt -- | A solution maps a type variable to its substitution. This substitution is complete, in the sense there are no right-hand sides that contain a type variable. -type Solution = M.Map TyVar (TypeBase () NoUniqueness) +type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) solution :: SolverState -> Solution solution s = M.mapMaybe mkSubst $ solverTyVars s where - mkSubst (TyVarSol _lvl t) = Just $ first (const ()) $ substTyVars (solverTyVars s) t + mkSubst (TyVarSol _lvl t) = Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t mkSubst (TyVarLink v') = mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (TyVarUnsol _ (TyVarPrim pts)) - | Signed Int32 `elem` pts = - Just (Scalar (Prim (Signed Int32))) -- XXX - we need warnings and things! + mkSubst (TyVarUnsol _ (TyVarPrim pts)) = Just $ Left pts mkSubst _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 967919c52d..4e0ea57060 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1666,6 +1666,7 @@ checkFunDef :: checkFunDef (fname, retdecl, tparams, params, body, loc) = do (maybe_tysubsts, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right tysubsts -> runTermTypeM checkExp tysubsts $ do diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 71318292ef..52933dd1a3 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -49,8 +49,10 @@ import Data.Bifunctor import Data.Char (isAscii) import Data.List qualified as L import Data.List.NonEmpty qualified as NE +import Data.Loc (Loc (NoLoc)) import Data.Map qualified as M import Data.Maybe +import Data.Set qualified as S import Data.Text qualified as T import Debug.Trace import Futhark.FreshNames qualified as FreshNames @@ -63,7 +65,7 @@ import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Rank import Language.Futhark.TypeChecker.Types -import Language.Futhark.TypeChecker.Unify (Level) +import Language.Futhark.TypeChecker.Unify (Level, mkUsage) import Prelude hiding (mod) data Inferred t @@ -430,7 +432,10 @@ patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v checkSizeExp :: ExpBase NoInfo VName -> TermM Exp -checkSizeExp = require "use as size" [Signed Int64] <=< checkExp +checkSizeExp e = do + e' <- checkExp e + ctEq (expType e') (Scalar (Prim (Signed Int64))) + pure e' checkPat' :: PatBase NoInfo VName ParamType -> @@ -1057,6 +1062,30 @@ checkExp (Coerce e te NoInfo loc) = do ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc +doDefaults :: + S.Set VName -> + VName -> + Either [PrimType] (TypeBase () NoUniqueness) -> + TermM (TypeBase () NoUniqueness) +doDefaults tyvars_at_toplevel v (Left pts) + | Signed Int32 `elem` pts = do + when (v `S.member` tyvars_at_toplevel) $ + warn usage "Defaulting ambiguous type to i32." + pure $ Scalar $ Prim $ Signed Int32 + | FloatType Float64 `elem` pts = do + when (v `S.member` tyvars_at_toplevel) $ + warn usage "Defaulting ambiguous type to f64." + pure $ Scalar $ Prim $ FloatType Float64 + | otherwise = + typeError usage mempty . withIndexLink "ambiguous-type" $ + "Type is ambiguous (could be one of" + <+> commasep (map pretty pts) + <> ")." + "Add a type annotation to disambiguate the type." + where + usage = mkUsage NoLoc "overload" +doDefaults _ _ (Right t) = pure t + checkValDef :: ( VName, Maybe (TypeExp (ExpBase NoInfo VName) VName), @@ -1066,7 +1095,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text Solution, + ( Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1099,7 +1128,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do Nothing -> error "" Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} - let solution = solve cts' tyvars' + + solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' traceM $ unlines @@ -1111,12 +1141,15 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do let p (v, t) = prettyNameString v <> " => " <> prettyString t in either T.unpack (unlines . map p . M.toList) solution ] + pure (solution, params', retdecl', body') -checkSingleExp :: ExpBase NoInfo VName -> TypeM (Either T.Text Solution, Exp) +checkSingleExp :: + ExpBase NoInfo VName -> + TypeM (Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - let solution = solve cts tyvars + solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts tyvars pure (solution, e') From 5f45f198ff8819d425388260652c7f5eceff8548 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 14:14:40 +0100 Subject: [PATCH 065/296] Tear out some organs we probably do not need anymore. --- src/Language/Futhark/TypeChecker/Terms.hs | 25 +----- .../Futhark/TypeChecker/Terms/Loop.hs | 4 +- .../Futhark/TypeChecker/Terms/Monad.hs | 27 +----- src/Language/Futhark/TypeChecker/Unify.hs | 84 ------------------- 4 files changed, 9 insertions(+), 131 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 4e0ea57060..58e1b009c3 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -390,7 +390,7 @@ checkExp (ArrayLit all_es _ loc) = t <- arrayOfM loc et (Shape [sizeFromInteger (genericLength all_es) mempty]) pure $ ArrayLit (e' : es') (Info t) loc checkExp (AppExp (Range start maybe_step end loc) _) = do - start' <- require "use in range expression" anySignedType =<< checkExp start + start' <- checkExp start start_t <- expType start' maybe_step' <- case maybe_step of Nothing -> pure Nothing @@ -530,10 +530,10 @@ checkExp (Var qn (Info t) loc) = do t' <- lookupVar loc qn t pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do - arg' <- require "numeric negation" anyNumberType =<< checkExp arg + arg' <- checkExp arg pure $ Negate arg' loc checkExp (Not arg loc) = do - arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg + arg' <- checkExp arg pure $ Not arg' loc checkExp (AppExp (Apply fe args loc) _) = do fe' <- checkExp fe @@ -662,7 +662,7 @@ checkExp (AppExp (Index e slice loc) _) = do pure $ AppExp (Index e' slice' loc) (Info $ AppRes t' retext) checkExp (Assert e1 e2 _ loc) = do - e1' <- require "being asserted" [Bool] =<< checkExp e1 + e1' <- checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do @@ -1286,23 +1286,6 @@ fixOverloadedTypes :: Names -> TermTypeM () fixOverloadedTypes tyvars_at_toplevel = getConstraints >>= mapM_ fixOverloaded . M.toList . M.map snd where - fixOverloaded (v, Overloaded ots usage) - | Signed Int32 `elem` ots = do - unify usage (Scalar (TypeVar mempty (qualName v) [])) $ - Scalar (Prim $ Signed Int32) - when (v `S.member` tyvars_at_toplevel) $ - warn usage "Defaulting ambiguous type to i32." - | FloatType Float64 `elem` ots = do - unify usage (Scalar (TypeVar mempty (qualName v) [])) $ - Scalar (Prim $ FloatType Float64) - when (v `S.member` tyvars_at_toplevel) $ - warn usage "Defaulting ambiguous type to f64." - | otherwise = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (could be one of" - <+> commasep (map pretty ots) - <> ")." - "Add a type annotation to disambiguate the type." fixOverloaded (v, NoConstraint _ usage) = do -- See #1552. unify usage (Scalar (TypeVar mempty (qualName v) [])) $ diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index 349b105823..7cba8af7e8 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -220,9 +220,7 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do (sparams, mergepat', form', loopbody') <- case form of For i uboundexp -> do - uboundexp' <- - require "being the bound in a 'for' loop" anySignedType - =<< checkExp uboundexp + uboundexp' <- checkExp uboundexp bindingIdent i . bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do loopbody' <- checkExp loopbody diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 87697bc359..8bdbb81daf 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -31,7 +31,6 @@ module Language.Futhark.TypeChecker.Terms.Monad -- * Primitive checking unifies, - require, checkTypeExpNonrigid, lookupVar, lookupMod, @@ -537,21 +536,10 @@ lookupVar loc qn@(QualName qs name) inst_t = do (tnames, t) <- instTypeScheme qn loc tparams bound_t $ first (const ()) inst_t outer_env <- asks termOuterEnv pure $ qualifyTypeVars outer_env tnames qs t - Just EqualityF -> do - argtype <- newTypeVar loc "t" - equalityType usage argtype - pure $ - Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ - Scalar $ - Arrow mempty Unnamed Observe argtype $ - RetType [] $ - Scalar $ - Prim Bool - Just (OverloadedF ts pts rt) -> do - argtype <- newTypeVar loc "t" - mustBeOneOf ts usage argtype - let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + Just EqualityF -> + replaceTyVars loc inst_t + Just OverloadedF {} -> + replaceTyVars loc inst_t where instOverloaded argtype pts rt = ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, @@ -639,13 +627,6 @@ unifies why t e = do unify (mkUsage (srclocOf e) why) t . toStruct =<< expType e pure e --- | @require ts e@ causes a 'TypeError' if @expType e@ is not one of --- the types in @ts@. Otherwise, simply returns @e@. -require :: T.Text -> [PrimType] -> Exp -> TermTypeM Exp -require why ts e = do - mustBeOneOf ts (mkUsage (srclocOf e) why) . toStruct =<< expType e - pure e - checkExpForSize :: ExpBase Info VName -> TermTypeM Exp checkExpForSize e = do checker <- asks termCheckExp diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 1d26af8354..8e1c414e6c 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -18,7 +18,6 @@ module Language.Futhark.TypeChecker.Unify arrayElemType, mustHaveConstr, mustHaveField, - mustBeOneOf, equalityType, normType, normTypeFully, @@ -119,7 +118,6 @@ data Constraint = NoConstraint Liftedness Usage | ParamType Liftedness Loc | Constraint StructRetType Usage - | Overloaded [PrimType] Usage | HasFields Liftedness (M.Map Name StructType) Usage | Equality Usage | HasConstrs Liftedness (M.Map Name [StructType]) Usage @@ -138,7 +136,6 @@ instance Located Constraint where locOf (NoConstraint _ usage) = locOf usage locOf (ParamType _ usage) = locOf usage locOf (Constraint _ usage) = locOf usage - locOf (Overloaded _ usage) = locOf usage locOf (HasFields _ _ usage) = locOf usage locOf (Equality usage) = locOf usage locOf (HasConstrs _ _ usage) = locOf usage @@ -282,8 +279,6 @@ typeVarNotes v = maybe mempty (note . snd) . M.lookup v <$> getConstraints <+> "=" <+> hsep (map ppConstr (M.toList cs)) <+> "..." - note (Overloaded ts _) = - aNote $ prettyName v <+> "must be one of" <+> mconcat (punctuate ", " (map pretty ts)) note (HasFields _ fs _) = aNote $ prettyName v @@ -685,26 +680,6 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do Just (Equality _) -> do link equalityType usage tp - Just (Overloaded ts old_usage) - | tp `notElem` map (Scalar . Prim) ts -> do - link - case tp of - Scalar (TypeVar _ (QualName [] v) []) - | not $ isRigid v constraints -> - linkVarToTypes usage v ts - _ -> - unifyError usage mempty bcs $ - "Cannot instantiate" - <+> dquotes (prettyName vn) - <+> "with type" - indent 2 (pretty tp) - "as" - <+> dquotes (prettyName vn) - <+> "must be one of" - <+> commasep (map pretty ts) - "due to" - <+> pretty old_usage - <> "." Just (HasFields l required_fields old_usage) -> do when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp case tp of @@ -846,63 +821,6 @@ linkVarToDim usage bcs vn lvl e = do _ -> modifyConstraints $ M.insert dim' (lvl, c) checkVar _ _ = pure () --- | Assert that this type must be one of the given primitive types. -mustBeOneOf :: (MonadUnify m) => [PrimType] -> Usage -> StructType -> m () -mustBeOneOf [req_t] usage t = unify usage (Scalar (Prim req_t)) t -mustBeOneOf ts usage t = do - t' <- normType t - constraints <- getConstraints - let isRigid' v = isRigid v constraints - - case t' of - Scalar (TypeVar _ (QualName [] v) []) - | not $ isRigid' v -> linkVarToTypes usage v ts - Scalar (Prim pt) | pt `elem` ts -> pure () - _ -> failure - where - failure = - unifyError usage mempty noBreadCrumbs $ - "Cannot unify type" - <+> dquotes (pretty t) - <+> "with any of " - <> commasep (map pretty ts) - <> "." - -linkVarToTypes :: (MonadUnify m) => Usage -> VName -> [PrimType] -> m () -linkVarToTypes usage vn ts = do - vn_constraint <- M.lookup vn <$> getConstraints - case vn_constraint of - Just (lvl, Overloaded vn_ts vn_usage) -> - case ts `intersect` vn_ts of - [] -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <+> "but also one of" - <+> commasep (map pretty vn_ts) - <+> "due to" - <+> pretty vn_usage - <> "." - ts' -> modifyConstraints $ M.insert vn (lvl, Overloaded ts' usage) - Just (_, HasConstrs _ _ vn_usage) -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <> ", but also inferred to be sum type due to" - <+> pretty vn_usage - <> "." - Just (_, HasFields _ _ vn_usage) -> - unifyError usage mempty noBreadCrumbs $ - "Type constrained to one of" - <+> commasep (map pretty ts) - <> ", but also inferred to be record due to" - <+> pretty vn_usage - <> "." - Just (lvl, _) -> modifyConstraints $ M.insert vn (lvl, Overloaded ts usage) - Nothing -> - unifyError usage mempty noBreadCrumbs $ - "Cannot constrain type to one of" <+> commasep (map pretty ts) - -- | Assert that this type must support equality. equalityType :: (MonadUnify m, Pretty (Shape dim), Pretty u) => @@ -932,8 +850,6 @@ equalityType usage t = do | otherwise -> pure () Just (lvl, NoConstraint _ _) -> modifyConstraints $ M.insert vn (lvl, Equality usage) - Just (_, Overloaded _ _) -> - pure () -- All primtypes support equality. Just (_, Equality {}) -> pure () _ -> From ae1f529e544854f30d090360bb4647ed1b930061 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 14:43:06 +0100 Subject: [PATCH 066/296] Remove more guts. --- src/Language/Futhark/TypeChecker/Terms.hs | 49 +--- src/Language/Futhark/TypeChecker/Unify.hs | 310 +--------------------- 2 files changed, 19 insertions(+), 340 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 58e1b009c3..08c7585de6 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -56,12 +56,6 @@ hasBinding e = isNothing $ astMap m e m = identityMapper {mapOnExp = \e' -> if hasBinding e' then Nothing else Just e'} -overloadedTypeVars :: Constraints -> Names -overloadedTypeVars = mconcat . map f . M.elems - where - f (_, HasFields _ fs _) = mconcat $ map typeVars $ M.elems fs - f _ = mempty - --- Basic checking -- | Determine if the two types are identical, ignoring uniqueness. @@ -495,8 +489,11 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' - kt <- mustHaveField (mkUsage loc $ docText $ "projection of field " <> dquotes (pretty k)) k t - pure $ Project k e' (Info kt) loc + case t of + Scalar (Record fs) + | Just kt <- M.lookup k fs -> + pure $ Project k e' (Info kt) loc + _ -> error $ "checkExp Project: " <> show t checkExp (AppExp (If e1 e2 e3 loc) _) = do e1' <- checkExp e1 e2' <- checkExp e2 @@ -765,12 +762,9 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (ProjectSection fields _ loc) = do - a <- newTypeVar loc "a" - let usage = mkUsage loc "projection at" - b <- foldM (flip $ mustHaveField usage) a fields - let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ toRes Nonunique b - pure $ ProjectSection fields (Info ft) loc +checkExp (ProjectSection fields (Info t) loc) = do + t' <- replaceTyVars loc t + pure $ ProjectSection fields (Info t') loc checkExp (IndexSection slice _ loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' @@ -1292,23 +1286,6 @@ fixOverloadedTypes tyvars_at_toplevel = Scalar (tupleRecord []) when (v `S.member` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to ()." - fixOverloaded (_, Equality usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (must be equality type)." - "Add a type annotation to disambiguate the type." - fixOverloaded (_, HasFields _ fs usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous. Must be record with fields:" - indent 2 (stack $ map field $ M.toList fs) - "Add a type annotation to disambiguate the type." - where - field (l, t) = pretty l <> colon <+> align (pretty t) - fixOverloaded (_, HasConstrs _ cs usage) = - typeError usage mempty . withIndexLink "ambiguous-type" $ - "Type is ambiguous (must be a sum type with constructors:" - <+> pretty (Sum cs) - <> ")." - "Add a type annotation to disambiguate the type." fixOverloaded (v, Size Nothing (Usage Nothing loc)) = typeError loc mempty . withIndexLink "ambiguous-size" $ "Ambiguous size" <+> dquotes (prettyName v) <> "." @@ -1552,18 +1529,12 @@ letGeneralise defname defloc tparams params restype = -- -- (2) are not used in the (new) definition of any type variables -- known before we checked this function. - -- - -- (3) are not referenced from an overloaded type (for example, - -- are the element types of an incompletely resolved record type). - -- This is a bit more restrictive than I'd like, and SML for - -- example does not have this restriction. - -- + -- Criteria (1) and (2) is implemented by looking at the binding -- level of the type variables. - let keep_type_vars = overloadedTypeVars now_substs cur_lvl <- curLevel - let candidate k (lvl, _) = (k `S.notMember` keep_type_vars) && lvl >= (cur_lvl - length params) + let candidate k (lvl, _) = lvl >= (cur_lvl - length params) new_substs = M.filterWithKey candidate now_substs (tparams', RetType ret_dims restype') <- diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 8e1c414e6c..4493b02b2d 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -16,9 +16,6 @@ module Language.Futhark.TypeChecker.Unify dimNotes, zeroOrderType, arrayElemType, - mustHaveConstr, - mustHaveField, - equalityType, normType, normTypeFully, unify, @@ -30,7 +27,7 @@ where import Control.Monad import Control.Monad.Except import Control.Monad.State -import Data.List (foldl', intersect) +import Data.List (foldl') import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S @@ -118,9 +115,6 @@ data Constraint = NoConstraint Liftedness Usage | ParamType Liftedness Loc | Constraint StructRetType Usage - | HasFields Liftedness (M.Map Name StructType) Usage - | Equality Usage - | HasConstrs Liftedness (M.Map Name [StructType]) Usage | ParamSize Loc | -- | Is not actually a type, but a term-level size, -- possibly already set to something specific. @@ -136,9 +130,6 @@ instance Located Constraint where locOf (NoConstraint _ usage) = locOf usage locOf (ParamType _ usage) = locOf usage locOf (Constraint _ usage) = locOf usage - locOf (HasFields _ _ usage) = locOf usage - locOf (Equality usage) = locOf usage - locOf (HasConstrs _ _ usage) = locOf usage locOf (ParamSize loc) = locOf loc locOf (Size _ usage) = locOf usage locOf (UnknownSize loc _) = locOf loc @@ -270,25 +261,6 @@ typeNotes ctx = . fvVars . freeInType -typeVarNotes :: (MonadUnify m) => VName -> m Notes -typeVarNotes v = maybe mempty (note . snd) . M.lookup v <$> getConstraints - where - note (HasConstrs _ cs _) = - aNote $ - prettyName v - <+> "=" - <+> hsep (map ppConstr (M.toList cs)) - <+> "..." - note (HasFields _ fs _) = - aNote $ - prettyName v - <+> "=" - <+> braces (mconcat (punctuate ", " (map ppField (M.toList fs)))) - note _ = mempty - - ppConstr (c, _) = "#" <> pretty c <+> "..." <+> "|" - ppField (f, _) = prettyName f <> ":" <+> "..." - -- | Monads that which to perform unification must implement this type -- class. class (Monad m) => MonadUnify m where @@ -354,12 +326,6 @@ unsharedConstructorsMsg cs1 cs2 = filter (`notElem` M.keys cs1) (M.keys cs2) ++ filter (`notElem` M.keys cs2) (M.keys cs1) --- | Is the given type variable the name of an abstract type or type --- parameter, which we cannot substitute? -isRigid :: VName -> Constraints -> Bool -isRigid v constraints = - maybe True (rigidConstraint . snd) $ M.lookup v constraints - -- | If the given type variable is nonrigid, what is its level? isNonRigid :: VName -> Constraints -> Maybe Level isNonRigid v constraints = do @@ -370,10 +336,6 @@ isNonRigid v constraints = do type UnifySizes m = BreadCrumbs -> [VName] -> (VName -> Maybe Int) -> Exp -> Exp -> m () -flipUnifySizes :: UnifySizes m -> UnifySizes m -flipUnifySizes onDims bcs bound nonrigid t1 t2 = - onDims bcs bound nonrigid t2 t1 - unifyWith :: (MonadUnify m) => UnifySizes m -> @@ -398,14 +360,7 @@ unifyWith onDims usage = subunify False failure = matchError (srclocOf usage) mempty bcs t1' t2' - link ord' = - linkVarToType linkDims usage bound bcs - where - -- We may have to flip the order of future calls to - -- onDims inside linkVarToType. - linkDims - | ord' = flipUnifySizes onDims - | otherwise = onDims + link = linkVarToType usage bound bcs unifyTypeArg bcs' (TypeArgDim d1) (TypeArgDim d2) = onDims' bcs' (swap ord d1 d2) @@ -452,17 +407,17 @@ unifyWith onDims usage = subunify False ) -> case (nonrigid v1, nonrigid v2) of (Nothing, Nothing) -> failure - (Just lvl1, Nothing) -> link ord v1 lvl1 t2' - (Nothing, Just lvl2) -> link (not ord) v2 lvl2 t1' + (Just lvl1, Nothing) -> link v1 lvl1 t2' + (Nothing, Just lvl2) -> link v2 lvl2 t1' (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> link ord v1 lvl1 t2' - | otherwise -> link (not ord) v2 lvl2 t1' + | lvl1 <= lvl2 -> link v1 lvl1 t2' + | otherwise -> link v2 lvl2 t1' (Scalar (TypeVar _ (QualName [] v1) []), _) | Just lvl <- nonrigid v1 -> - link ord v1 lvl t2' + link v1 lvl t2' (_, Scalar (TypeVar _ (QualName [] v2) [])) | Just lvl <- nonrigid v2 -> - link (not ord) v2 lvl t1' + link v2 lvl t1' ( Scalar (Arrow _ p1 d1 a1 (RetType b1_dims b1)), Scalar (Arrow _ p2 d2 a2 (RetType b2_dims b2)) ) @@ -625,7 +580,6 @@ scopeCheck usage bcs vn max_lvl tp = do linkVarToType :: (MonadUnify m) => - UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -633,7 +587,7 @@ linkVarToType :: Level -> StructType -> m () -linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do +linkVarToType usage bound bcs vn lvl tp_unnorm = do -- We have to expand anyway for the occurs check, so we might as -- well link the fully expanded type. tp <- normTypeFully tp_unnorm @@ -677,105 +631,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do <+> "cannot be instantiated with type containing anonymous sizes:" indent 2 (pretty tp) textwrap "This is usually because the size of an array returned by a higher-order function argument cannot be determined statically. This can also be due to the return size being a value parameter. Add type annotation to clarify." - Just (Equality _) -> do - link - equalityType usage tp - Just (HasFields l required_fields old_usage) -> do - when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of - Scalar (Record tp_fields) - | all (`M.member` tp_fields) $ M.keys required_fields -> do - required_fields' <- mapM normTypeFully required_fields - let tp' = Scalar $ Record $ required_fields <> tp_fields -- Crucially left-biased. - ext = filter (`S.member` fvVars (freeInType tp')) bound - modifyConstraints $ - M.insert vn (lvl, Constraint (RetType ext tp') usage) - unifySharedFields onDims usage bound bcs required_fields' tp_fields - Scalar (TypeVar _ (QualName [] v) []) -> do - case M.lookup v constraints of - Just (_, HasFields _ tp_fields _) -> - unifySharedFields onDims usage bound bcs required_fields tp_fields - Just (_, NoConstraint {}) -> pure () - Just (_, Equality {}) -> pure () - _ -> do - notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v - noRecordType notes - link - modifyConstraints $ - M.insertWith - combineFields - v - (lvl, HasFields l required_fields old_usage) - where - combineFields (_, HasFields l1 fs1 usage1) (_, HasFields l2 fs2 _) = - (lvl, HasFields (l1 `min` l2) (M.union fs1 fs2) usage1) - combineFields hasfs _ = hasfs - _ -> - unifyError usage mempty bcs $ - "Cannot instantiate" - <+> dquotes (prettyName vn) - <+> "with type" - indent 2 (pretty tp) - "as" - <+> dquotes (prettyName vn) - <+> "must be a record with fields" - indent 2 (pretty (Record required_fields)) - "due to" - <+> pretty old_usage - <> "." - -- See Note [Linking variables to sum types] - Just (HasConstrs l required_cs old_usage) -> do - when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp - case tp of - Scalar (Sum ts) - | all (`M.member` ts) $ M.keys required_cs -> do - let tp' = Scalar $ Sum $ required_cs <> ts -- Crucially left-biased. - ext = filter (`S.member` fvVars (freeInType tp')) bound - modifyConstraints $ - M.insert vn (lvl, Constraint (RetType ext tp') usage) - unifySharedConstructors onDims usage bound bcs required_cs ts - | otherwise -> - unsharedConstructors required_cs ts =<< typeVarNotes vn - Scalar (TypeVar _ (QualName [] v) []) -> do - case M.lookup v constraints of - Just (_, HasConstrs _ v_cs _) -> - unifySharedConstructors onDims usage bound bcs required_cs v_cs - Just (_, NoConstraint {}) -> pure () - Just (_, Equality {}) -> pure () - _ -> do - notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v - noSumType notes - link - modifyConstraints $ - M.insertWith - combineConstrs - v - (lvl, HasConstrs l required_cs old_usage) - where - combineConstrs (_, HasConstrs l1 cs1 usage1) (_, HasConstrs l2 cs2 _) = - (lvl, HasConstrs (l1 `min` l2) (M.union cs1 cs2) usage1) - combineConstrs hasCs _ = hasCs - _ -> noSumType =<< typeVarNotes vn _ -> link - where - unsharedConstructors cs1 cs2 notes = - unifyError - usage - notes - bcs - (unsharedConstructorsMsg cs1 cs2) - noSumType notes = - unifyError - usage - notes - bcs - "Cannot unify a sum type with a non-sum type." - noRecordType notes = - unifyError - usage - notes - bcs - "Cannot unify a record type with a non-record type." linkVarToDim :: (MonadUnify m) => @@ -821,41 +677,6 @@ linkVarToDim usage bcs vn lvl e = do _ -> modifyConstraints $ M.insert dim' (lvl, c) checkVar _ _ = pure () --- | Assert that this type must support equality. -equalityType :: - (MonadUnify m, Pretty (Shape dim), Pretty u) => - Usage -> - TypeBase dim u -> - m () -equalityType usage t = do - unless (orderZero t) $ - unifyError usage mempty noBreadCrumbs $ - "Type " <+> dquotes (pretty t) <+> "does not support equality (may contain function)." - mapM_ mustBeEquality $ typeVars t - where - mustBeEquality vn = do - constraints <- getConstraints - case M.lookup vn constraints of - Just (_, Constraint (RetType [] (Scalar (TypeVar _ (QualName [] vn') []))) _) -> - mustBeEquality vn' - Just (_, Constraint (RetType _ vn_t) cusage) - | not $ orderZero vn_t -> - unifyError usage mempty noBreadCrumbs $ - "Type" - <+> dquotes (pretty t) - <+> "does not support equality." - "Constrained to be higher-order due to" - <+> pretty cusage - <+> "." - | otherwise -> pure () - Just (lvl, NoConstraint _ _) -> - modifyConstraints $ M.insert vn (lvl, Equality usage) - Just (_, Equality {}) -> - pure () - _ -> - unifyError usage mempty noBreadCrumbs $ - "Type" <+> prettyName vn <+> "does not support equality." - zeroOrderTypeWith :: (MonadUnify m) => Usage -> @@ -873,10 +694,6 @@ zeroOrderTypeWith usage bcs t = do case M.lookup vn constraints of Just (lvl, NoConstraint _ _) -> modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage) - Just (lvl, HasFields _ fs _) -> - modifyConstraints $ M.insert vn (lvl, HasFields Unlifted fs usage) - Just (lvl, HasConstrs _ cs _) -> - modifyConstraints $ M.insert vn (lvl, HasConstrs Unlifted cs usage) Just (_, ParamType Lifted ploc) -> unifyError usage mempty bcs $ "Type parameter" @@ -967,96 +784,6 @@ unifySharedConstructors onDims usage bound bcs cs1 cs2 = unifyError usage mempty bcs $ "Cannot unify constructor" <+> dquotes (prettyName c) <> "." --- | In @mustHaveConstr usage c t fs@, the type @t@ must have a --- constructor named @c@ that takes arguments of types @ts@. -mustHaveConstr :: - (MonadUnify m) => - Usage -> - Name -> - StructType -> - [StructType] -> - m () -mustHaveConstr usage c t fs = do - constraints <- getConstraints - case t of - Scalar (TypeVar _ (QualName _ tn) []) - | Just (lvl, NoConstraint l _) <- M.lookup tn constraints -> do - mapM_ (scopeCheck usage noBreadCrumbs tn lvl) fs - modifyConstraints $ M.insert tn (lvl, HasConstrs l (M.singleton c fs) usage) - | Just (lvl, HasConstrs l cs _) <- M.lookup tn constraints -> - case M.lookup c cs of - Nothing -> - modifyConstraints $ - M.insert tn (lvl, HasConstrs l (M.insert c fs cs) usage) - Just fs' - | length fs == length fs' -> zipWithM_ (unify usage) fs fs' - | otherwise -> - unifyError usage mempty noBreadCrumbs $ - "Different arity for constructor" <+> dquotes (pretty c) <> "." - Scalar (Sum cs) -> - case M.lookup c cs of - Nothing -> - unifyError usage mempty noBreadCrumbs $ - "Constuctor" <+> dquotes (pretty c) <+> "not present in type." - Just fs' - | length fs == length fs' -> zipWithM_ (unify usage) fs fs' - | otherwise -> - unifyError usage mempty noBreadCrumbs $ - "Different arity for constructor" <+> dquotes (pretty c) <+> "." - _ -> - unify usage t $ Scalar $ Sum $ M.singleton c fs - -mustHaveFieldWith :: - (MonadUnify m) => - UnifySizes m -> - Usage -> - [VName] -> - BreadCrumbs -> - Name -> - StructType -> - m StructType -mustHaveFieldWith onDims usage bound bcs l t = do - constraints <- getConstraints - l_type <- newTypeVar (locOf usage) "t" - case t of - Scalar (TypeVar _ (QualName _ tn) []) - | Just (lvl, NoConstraint {}) <- M.lookup tn constraints -> do - scopeCheck usage bcs tn lvl l_type - modifyConstraints $ M.insert tn (lvl, HasFields Lifted (M.singleton l l_type) usage) - pure l_type - | Just (lvl, HasFields lifted fields _) <- M.lookup tn constraints -> do - case M.lookup l fields of - Just t' -> unifyWith onDims usage bound bcs l_type t' - Nothing -> - modifyConstraints $ - M.insert - tn - (lvl, HasFields lifted (M.insert l l_type fields) usage) - pure l_type - Scalar (Record fields) - | Just t' <- M.lookup l fields -> do - unify usage l_type t' - pure t' - | otherwise -> - unifyError usage mempty bcs $ - "Attempt to access field" - <+> dquotes (pretty l) - <+> " of value of type" - <+> pretty (toStructural t) - <> "." - _ -> do - unify usage t $ Scalar $ Record $ M.singleton l l_type - pure l_type - --- | Assert that some type must have a field with this name and type. -mustHaveField :: - (MonadUnify m) => - Usage -> - Name -> - StructType -> - m StructType -mustHaveField usage = mustHaveFieldWith (unifySizes usage) usage mempty noBreadCrumbs - newDimOnMismatch :: (MonadUnify m) => Loc -> @@ -1180,22 +907,3 @@ doUnification loc rigid_tparams nonrigid_tparams t1 t2 = runUnifyM rigid_tparams nonrigid_tparams $ do unify (Usage Nothing (locOf loc)) t1 t2 normTypeFully t2 - --- Note [Linking variables to sum types] --- --- Consider the case when unifying a result type --- --- i32 -> ?[n].(#foo [n]bool) --- --- with --- --- i32 -> ?[k].a --- --- where 'a' has a HasConstrs constraint saying that it must have at --- least a constructor of type '#foo [0]bool'. --- --- This unification should succeed, but we must not merely link 'a' to --- '#foo [n]bool', as 'n' is not free. Instead we should instantiate --- 'a' to be a concrete sum type (because now we know exactly which --- constructor labels it must have), and unify each of its constructor --- payloads with the corresponding expected payload. From 0f6156a2dc914c94c53215f8517e8928ccf3c619 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 18:15:38 +0100 Subject: [PATCH 067/296] Must also unify here. --- src/Language/Futhark/TypeChecker/Terms.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 08c7585de6..28e2e9f2e5 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -676,6 +676,8 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' + unify (mkUsage body "inferred return type") (toStruct rt') body_t + params'' <- mapM updateTypes params' (rettype', rettype_st) <- @@ -1534,8 +1536,8 @@ letGeneralise defname defloc tparams params restype = -- level of the type variables. cur_lvl <- curLevel - let candidate k (lvl, _) = lvl >= (cur_lvl - length params) - new_substs = M.filterWithKey candidate now_substs + let candidate (lvl, _) = lvl >= (cur_lvl - length params) + new_substs = M.filter candidate now_substs (tparams', RetType ret_dims restype') <- closeOverTypes From 9bf72c86dbfb95bfb9106d2b9e508b9702895de6 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 17 Feb 2024 18:50:48 +0100 Subject: [PATCH 068/296] AUTOMAP does not work yet. --- prelude/zip.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prelude/zip.fut b/prelude/zip.fut index 18361e545f..1171820307 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -11,7 +11,7 @@ -- depended upon by soacs.fut. So we just define a quick-and-dirty -- internal one here that uses the intrinsic version. local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - f as + intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = From 7d3bcc7e46942e08f1343f2590576021179baf4d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sat, 17 Feb 2024 21:11:31 -0800 Subject: [PATCH 069/296] Defunctionalization and internalization AUTOMAP progress. --- src/Futhark/IR/Syntax/Core.hs | 6 ++ src/Futhark/Internalise/Defunctionalise.hs | 64 +++++++++-- src/Futhark/Internalise/Exps.hs | 119 +++++++++++++++++++-- src/Language/Futhark/Prop.hs | 14 +++ src/Language/Futhark/TypeChecker/Rank.hs | 4 +- src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++--- 6 files changed, 204 insertions(+), 35 deletions(-) diff --git a/src/Futhark/IR/Syntax/Core.hs b/src/Futhark/IR/Syntax/Core.hs index 227c25b23b..982fadcdec 100644 --- a/src/Futhark/IR/Syntax/Core.hs +++ b/src/Futhark/IR/Syntax/Core.hs @@ -15,6 +15,7 @@ module Futhark.IR.Syntax.Core ShapeBase (..), Shape, stripDims, + takeDims, Ext (..), ExtSize, ExtShape, @@ -128,6 +129,11 @@ instance Monoid (ShapeBase d) where stripDims :: Int -> ShapeBase d -> ShapeBase d stripDims n (Shape dims) = Shape $ drop n dims +-- | @takeDims n shape@ takes the outer @n@ dimensions from +-- @shape@. If @shape@ has m <= n dimensions, it returns $shape$. +takeDims :: Int -> ShapeBase d -> ShapeBase d +takeDims n (Shape dims) = Shape $ take n dims + -- | The size of an array as a list of subexpressions. If a variable, -- that variable must be in scope where this array is used. type Shape = ShapeBase SubExp diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 82cc845d69..d5894bec66 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -14,6 +14,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Debug.Trace import Futhark.IR.Pretty () import Futhark.MonadFreshNames import Futhark.Util (mapAccumLM, nubOrd) @@ -905,7 +906,7 @@ defuncApplyArg :: (Exp, StaticVal) -> (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) -defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do +defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, am), arg), _) = do (arg', arg_sv) <- defuncExp arg let env' = alwaysMatchPatSV pat arg_sv dims = mempty @@ -955,20 +956,29 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _ fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) callret <- unRetType lifted_rettype + traceM $ + unlines + [ "sv", + show sv, + "ret sv", + show $ autoMapSV (autoMap am) sv + ] + pure - ( mkApply fname' [(Nothing, mempty, f'), (argext, mempty, arg')] callret, - sv + ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, + autoMapSV (autoMap am) sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. -defuncApplyArg _ (f', DynamicFun _ sv) (((argext, _), arg), argtypes) = do +defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] - apply_e = mkApply f' [(argext, mempty, arg')] callret - pure (apply_e, sv) + apply_e = mkApply f' [(argext, am, arg')] callret + -- pure (apply_e, autoMapSV (autoRep am) sv) + pure (apply_e, autoMapSV (autoMap am) sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = error $ @@ -984,6 +994,11 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e +autoMapSV :: Shape Size -> StaticVal -> StaticVal +autoMapSV shape (Dynamic t) = + Dynamic $ arrayOfWithAliases (diet t) shape t +autoMapSV _ sv = sv + defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) @@ -999,10 +1014,39 @@ defuncApply f args appres loc = do _ -> do let fname = liftedName 0 f (argtypes, _) = unfoldFunType $ typeOf f - fmap (first $ updateReturn appres) $ - foldM (defuncApplyArg fname) (f', f_sv) $ - NE.zip args $ - NE.tails argtypes + (app, app_sv) <- + fmap (first $ updateReturn appres) $ + foldM (defuncApplyArg fname) (f', f_sv) $ + NE.zip args $ + NE.tails argtypes + + let (p_ts, _) = unfoldFunType $ typeOf f + arg_ts = typeOf . snd <$> args + -- am_dims = zipWith typeShapePrefix (NE.toList arg_ts) p_ts + -- ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) am_dims + ams = NE.toList $ autoMap . snd . fst <$> args + ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) ams + traceM $ + unlines + [ "## defuncApply", + "## f", + prettyString f, + "## args", + prettyString $ snd <$> args, + "## appres", + show appres, + "## app", + prettyString app, + "## app_sv", + show app_sv, + "## f type", + prettyString $ typeOf f, + "## arg types", + prettyString $ (typeOf . snd) <$> args, + "## ret_am", + prettyString ret_am + ] + pure (app, autoMapSV ret_am $ app_sv) where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index ec4adece0b..1b96819dc8 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -8,13 +8,15 @@ module Futhark.Internalise.Exps (transformProg) where import Control.Monad import Control.Monad.Reader import Data.Bifunctor +import Data.Either import Data.Foldable (toList) -import Data.List (elemIndex, find, intercalate, intersperse, transpose) +import Data.List (elemIndex, find, intercalate, intersperse, maximumBy, transpose, zip4) import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Set qualified as S import Data.Text qualified as T +import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings @@ -346,12 +348,15 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = let subst = map (,E.ExpSubst (E.sizeFromInteger 0 mempty)) ext et' = E.applySubst (`lookup` subst) et internaliseExp desc (E.Hole (Info et') loc) - (FunctionName qfname, args) -> do + (FunctionName qfname, argsam) -> do -- Argument evaluation is outermost-in so that any existential sizes -- created by function applications can be brought into scope. let fname = nameFromString $ prettyString $ baseName $ qualLeaf qfname loc = srclocOf e arg_desc = nameToString fname ++ "_arg" + args = map (\(a, b, _) -> (a, b)) argsam + ams = map (\(_, _, c) -> c) argsam + res_t = et -- Some functions are magical (overloaded) and we handle that here. case () of @@ -388,8 +393,16 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = let args'' = concatMap tag args' letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) - funcall desc qfname args' loc + traceM $ + unlines + [ "## qfname", + prettyString qfname + ] + -- args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) + -- funcall desc qfname args' loc + + withAutoMap_ ams arg_desc res_t args $ \args' -> + funcall desc qfname (concat args') loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = internalisePat desc sizes pat e $ internaliseExp desc body internaliseAppExp _ _ (E.LetFun ofname _ _ _) = @@ -890,6 +903,98 @@ internalisePatLit (E.PatLitFloat x) (E.Scalar (E.Prim (E.FloatType ft))) = internalisePatLit l t = error $ "Nonsensical pattern and type: " ++ show (l, t) +withAutoMap_ :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([[SubExp]] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] +withAutoMap_ ams arg_desc res_t args_e innerM = + withAutoMap ams arg_desc res_t args_e $ \args_stms -> do + let (args, stms) = unzip args_stms + mapM_ addStms $ reverse stms + innerM args + +withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] +withAutoMap ams arg_desc res_t args_e innerM = do + (args, stms) <- + foldM + ( \(args, stms) arg -> do + (arg', stms') <- inScopeOf (reverse stms) $ collectStms $ internaliseArg arg_desc arg + pure (arg' : args, stms' : stms) + ) + (mempty, mempty) + (reverse args_e) + argts <- inScopeOf (reverse stms) $ (mapM . mapM) subExpType args + expand args stms argts ams (maximum ds) + where + stripAutoMapDims i am = + am {autoMap = E.Shape $ drop i $ E.shapeDims $ autoMap am} + autoMapRank = E.shapeRank . autoMap + max_am = maximumBy (\x y -> E.shapeRank x `compare` E.shapeRank y) $ fmap autoMap ams + inner_t = E.stripArray (E.shapeRank max_am) res_t + ds = map autoMapRank ams + mkLambdaParams level (ses, ts, stm, d) + | d == level = + Left + <$> zipWithM + ( \se t -> do + let t' = I.stripArray 1 t + p <- newParam "x" t' + addStms stm + pure ((se, p), t') + ) + ses + ts + | otherwise = pure $ Right $ zip ses ts + + expand args stms argts ams' level + | level <= 0 = innerM $ zip args stms + | otherwise = do + let ds' = map autoMapRank ams' + arg_params <- mapM (mkLambdaParams level) $ zip4 args argts stms ds' + let argts' = map (either (map snd) (map snd)) arg_params + (ams'', stms') = + unzip $ + zipWith + ( \am stm -> + if autoMapRank am == level + then (stripAutoMapDims 1 am, mempty) + else (am, stm) + ) + ams' + stms + args' = map (either (map (I.Var . I.paramName . snd . fst)) (map fst)) arg_params + (map_ses, params) = unzip $ (concatMap . map) fst $ lefts arg_params + + ((ses, ses_ts), lam_stms) <- collectStms $ localScope (scopeOfLParams params) $ do + ses <- expand args' stms' argts' ams'' (level - 1) + ses_ts <- internaliseLambdaReturnType (E.toRes Nonunique inner_t) =<< mapM subExpType ses + pure (ses, ses_ts) + + case map_ses of + [] -> pure mempty + (map_se : _) -> do + outer_shape <- I.takeDims 1 . I.arrayShape <$> subExpType map_se + let I.Shape [outer_shape_se] = outer_shape + map_args <- forM map_ses $ \se -> do + se_t <- subExpType se + se_name <- letExp "map_arg" =<< toExp se + letExp "reshaped" $ + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter outer_shape 1 $ I.arrayShape se_t) + se_name + + letValExp' "automap" + . Op + . Screma outer_shape_se map_args + . mapSOAC + =<< mkLambda + params + ( ensureResultShape + (ErrorMsg [ErrorString "AutoMap: unexpected lambda result size"]) + mempty + ses_ts + =<< (addStms lam_stms >> pure (subExpsRes ses)) + ) + generateCond :: E.Pat StructType -> [I.SubExp] -> @@ -1477,14 +1582,14 @@ data Function | FunctionHole SrcLoc deriving (Show) -findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)]) +findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName, AutoMap)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) - | E.Hole (Info _) loc <- f = + | E.Hole (Info t) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info (argext, _), e) = (e, argext) + onArg (Info (argext, am), e) = (e, argext, am) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 419e50ba6b..6c69e1ef3d 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -33,6 +33,8 @@ module Language.Futhark.Prop stripExp, similarExps, frameOf, + shapePrefix, + typeShapePrefix, -- * Queries on patterns and params patIdents, @@ -1443,6 +1445,18 @@ frameOf (AppExp (Apply _ args _) _) = ((\(_, am) -> autoFrame am) . unInfo . fst) $ NE.last args frameOf _ = mempty +-- | @s1 `shapePrefix` s2@ assumes @s1 = prefix <> s2@ and +-- returns @prefix@. +shapePrefix :: Shape dim -> Shape dim -> Shape dim +shapePrefix (Shape ss1) (Shape ss2) = + Shape $ take (length ss1 - length ss2) ss1 + +typeShapePrefix :: TypeBase dim as1 -> TypeBase dim as2 -> Shape dim +typeShapePrefix (Array _ s _) Scalar {} = s +typeShapePrefix (Array _ s1 _) (Array _ s2 _) = + s1 `shapePrefix` s2 +typeShapePrefix _ _ = mempty + -- | An identifier with type- and aliasing information. type Ident = IdentBase Info VName diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 23d295f8ee..14215a7338 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -144,14 +144,14 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_glpk vns counter cs tyVars = do - traceM $ unlines ["## rankAnalysis prog", prettyString prog] + -- traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- if use_glpk then snd <$> (unsafePerformIO $ glpk prog) else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + -- traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) let initEnv = SubstEnv { envTyVars = tyVars, diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 52933dd1a3..d33a2ae17f 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1112,17 +1112,17 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + -- traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" vns <- gets termNameSource let use_glpk = True - traceM $ - unlines - [ "## cts:", - unlines $ map prettyString cts - ] + -- traceM $ + -- unlines + -- [ "## cts:", + -- unlines $ map prettyString cts + -- ] case rankAnalysis use_glpk vns counter cts tyvars of Nothing -> error "" @@ -1131,16 +1131,16 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' - traceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList) solution - ] + -- traceM $ + -- unlines + -- [ "## constraints:", + -- unlines $ map prettyString cts', + -- "## tyvars:", + -- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + -- "## solution:", + -- let p (v, t) = prettyNameString v <> " => " <> prettyString t + -- in either T.unpack (unlines . map p . M.toList) solution + -- ] pure (solution, params', retdecl', body') From 3e091b70b8588d8e389d3ac1b67ca3bdf35fb62e Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 18 Feb 2024 07:31:47 -0800 Subject: [PATCH 070/296] Oops. --- src/Language/Futhark/TypeChecker/Rank.hs | 4 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++++++++++----------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 14215a7338..23d295f8ee 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -144,14 +144,14 @@ mkLinearProg counter cs tyVars = rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_glpk vns counter cs tyVars = do - -- traceM $ unlines ["## rankAnalysis prog", prettyString prog] + traceM $ unlines ["## rankAnalysis prog", prettyString prog] rank_map <- if use_glpk then snd <$> (unsafePerformIO $ glpk prog) else do (_size, ranks) <- branchAndBound lp pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - -- traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) let initEnv = SubstEnv { envTyVars = tyVars, diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d33a2ae17f..52933dd1a3 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1112,17 +1112,17 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - -- traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" vns <- gets termNameSource let use_glpk = True - -- traceM $ - -- unlines - -- [ "## cts:", - -- unlines $ map prettyString cts - -- ] + traceM $ + unlines + [ "## cts:", + unlines $ map prettyString cts + ] case rankAnalysis use_glpk vns counter cts tyvars of Nothing -> error "" @@ -1131,16 +1131,16 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' - -- traceM $ - -- unlines - -- [ "## constraints:", - -- unlines $ map prettyString cts', - -- "## tyvars:", - -- unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - -- "## solution:", - -- let p (v, t) = prettyNameString v <> " => " <> prettyString t - -- in either T.unpack (unlines . map p . M.toList) solution - -- ] + traceM $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList) solution + ] pure (solution, params', retdecl', body') From 89939231df90c2b3e6208b64586432c7071636fe Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 18 Feb 2024 07:59:06 -0800 Subject: [PATCH 071/296] Basic map-only AUTOMAP seems to work now. --- src/Futhark/Internalise/Defunctionalise.hs | 28 ++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index d5894bec66..249616ef67 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -958,15 +958,23 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a traceM $ unlines - [ "sv", + [ "##defuncApplyArg LambdaSV", + "## fname", + fname_s, + "## f'", + prettyString f', + "## arg", + prettyString arg, + "## sv", show sv, - "ret sv", + "## ret sv", show $ autoMapSV (autoMap am) sv ] pure ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, autoMapSV (autoMap am) sv + -- sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially @@ -977,8 +985,20 @@ defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] apply_e = mkApply f' [(argext, am, arg')] callret - -- pure (apply_e, autoMapSV (autoRep am) sv) + traceM $ + unlines + [ "##defuncApplyArg DynamicFun", + "## f'", + prettyString f', + "## arg", + prettyString arg, + "## sv", + show sv, + "## ret sv", + show $ autoMapSV (autoMap am) sv + ] pure (apply_e, autoMapSV (autoMap am) sv) +-- pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = error $ @@ -1046,7 +1066,7 @@ defuncApply f args appres loc = do "## ret_am", prettyString ret_am ] - pure (app, autoMapSV ret_am $ app_sv) + pure (app, app_sv) where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. From 4c8a1248a282a53a229b5f6a10e8d8b0247bbebb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:28:57 +0100 Subject: [PATCH 072/296] Please shut up. --- src/Futhark/Solve/GLPK.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index 7b27408a27..fe7ac5d129 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -43,5 +43,7 @@ glpk :: F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) glpk lp = do - (_, mres) <- glpSolveVars mipDefaults $ linearProgToGLPK lp + (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres + where + opts = mipDefaults {msgLev = MsgOff} From 6407ff3ee9e8604364b0c939ed779be26b1695f1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:40:49 +0100 Subject: [PATCH 073/296] Handle automapped operand. --- src/Language/Futhark/TypeChecker/Terms.hs | 38 +++-------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 28e2e9f2e5..1f12e844d4 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -987,39 +987,11 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do } pure (tp1, tp2'', argext, ext, am) -checkApply loc fname tfun@(Scalar TypeVar {}) arg = do - tv <- newTypeVar loc "b" - unify (mkUsage loc "use as function") tfun $ - Scalar (Arrow mempty Unnamed Observe (typeOf arg) $ RetType [] $ paramToRes tv) - tfun' <- normType tfun - checkApply loc fname tfun' arg -checkApply loc (fname, prev_applied) ftype argexp = do - let fname' = maybe "expression" (dquotes . pretty) fname - - typeError loc mempty $ - if prev_applied == 0 - then - "Cannot apply" - <+> fname' - <+> "as function, as it has type:" - indent 2 (pretty ftype) - else - "Cannot apply" - <+> fname' - <+> "to argument #" - <> pretty (prev_applied + 1) - <+> dquotes (shorten $ group $ pretty argexp) - <> "," - "as" - <+> fname' - <+> "only takes" - <+> pretty prev_applied - <+> arguments - <> "." - where - arguments - | prev_applied == 1 = "argument" - | otherwise = "arguments" +checkApply loc fname (Array _ _ t) arg = + -- This implies the function is the result of an automap. + checkApply loc fname (Scalar t) arg +checkApply _ _ _ _ = + error "checkApply: impossible case" -- | Type-check a single expression in isolation. This expression may -- turn out to be polymorphic, in which case the list of type From 9f1093241a154bb2d6d1cf1d243bbce41b1ef503 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:45:10 +0100 Subject: [PATCH 074/296] Put these adjacent. --- src/Language/Futhark/TypeChecker/Terms.hs | 36 +++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 1f12e844d4..d8766a8b2b 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -466,6 +466,24 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc +checkExp (AppExp (Apply fe args loc) _) = do + fe' <- checkExp fe + args' <- mapM (checkExp . snd) args + t <- expType fe' + let fname = + case fe' of + Var v _ _ -> Just v + _ -> Nothing + ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' + + pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts + where + onArg fname (i, all_exts, t) arg' = do + (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' + pure + ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), + (Info (argext, am), arg') + ) checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do ftype <- lookupVar oploc op op_t e1' <- checkExp e1 @@ -532,24 +550,6 @@ checkExp (Negate arg loc) = do checkExp (Not arg loc) = do arg' <- checkExp arg pure $ Not arg' loc -checkExp (AppExp (Apply fe args loc) _) = do - fe' <- checkExp fe - args' <- mapM (checkExp . snd) args - t <- expType fe' - let fname = - case fe' of - Var v _ _ -> Just v - _ -> Nothing - ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' - - pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts - where - onArg fname (i, all_exts, t) arg' = do - (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' - pure - ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), - (Info (argext, am), arg') - ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e From 2ea1c6af433b2a53510aa9a4cab2100bf9720b44 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 18 Feb 2024 17:48:23 +0100 Subject: [PATCH 075/296] Add frame to binop result. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d8766a8b2b..7c97574b15 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -503,7 +503,7 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do (e2', Info (p2_ext, am2)) loc ) - (Info (AppRes rt' retext)) + (Info (AppRes (arrayOf (autoFrame am2) rt') retext)) checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' From b644e3f6e7aba99a247d6b6c4004266596b7a608 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 18 Feb 2024 11:50:51 -0800 Subject: [PATCH 076/296] AUTOMAP `OpSection` support. --- src/Futhark/Internalise/Exps.hs | 27 ++++-- src/Futhark/Internalise/FullNormalise.hs | 8 +- src/Futhark/Internalise/Monomorphise.hs | 22 ++--- src/Language/Futhark/Interpreter.hs | 4 +- src/Language/Futhark/Syntax.hs | 4 +- src/Language/Futhark/Traversals.hs | 8 +- src/Language/Futhark/TypeChecker/Terms.hs | 14 ++-- src/Language/Futhark/TypeChecker/Terms2.hs | 96 +++++++++++++++++----- 8 files changed, 124 insertions(+), 59 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 1b96819dc8..2e3cabd4f5 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -381,17 +381,28 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - let prepareArg (arg, _) = - (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg - internalise =<< mapM prepareArg args + -- let prepareArg (arg, _) = + -- (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg + -- internalise =<< mapM prepareArg args + -- + withAutoMap_ ams arg_desc res_t args $ \args' -> do + let prepareArg (arg, _, am) arg' = + (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') + internalise $ zipWith prepareArg argsam args' | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do - let tag ses = [(se, I.Observe) | se <- ses] - args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) - let args'' = concatMap tag args' - letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) + -- Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do + -- let tag ses = [(se, I.Observe) | se <- ses] + -- args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) + -- let args'' = concatMap tag args' + -- letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) + -- + Just (rettype, _) <- M.lookup fname I.builtInFunctions -> + withAutoMap_ ams arg_desc res_t args $ \args' -> do + let tag ses = [(se, I.Observe) | se <- ses] + let args'' = concatMap tag args' + letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do traceM $ unlines diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 91e16a9a53..6ee354ea4f 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -203,13 +203,13 @@ getOrdering final (Lambda params body mte ret loc) = do nameExp final $ Lambda params body' mte ret loc getOrdering _ (OpSection qn ty loc) = pure $ Var qn ty loc -getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do +getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, xam), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do x <- getOrdering False e yn <- newNameFromString "y" let y = Var (qualName yn) (Info $ toStruct yty) mempty ret' = applySubst (pSubst x y) ret body = - mkApply (Var op ty mempty) [(xext, mempty, x), (Nothing, mempty, y)] $ + mkApply (Var op ty mempty) [(xext, xam, x), (Nothing, mempty, y)] $ AppRes (toStruct ret') exts nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc where @@ -217,12 +217,12 @@ getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (I | Named p <- xp, p == vn = Just $ ExpSubst x | Named p <- yp, p == vn = Just $ ExpSubst y | otherwise = Nothing -getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext)) (Info (RetType dims ret)) loc) = do +getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext, yam)) (Info (RetType dims ret)) loc) = do xn <- newNameFromString "x" y <- getOrdering False e let x = Var (qualName xn) (Info $ toStruct xty) mempty ret' = applySubst (pSubst x y) ret - body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, mempty, y)] $ AppRes (toStruct ret') [] + body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, yam, y)] $ AppRes (toStruct ret') [] nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc where pSubst x y vn diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index ee352f67c6..aada3924c0 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -664,27 +664,27 @@ transformExp (Lambda params e0 decl tp loc) = do transformExp (OpSection qn t loc) = transformExp $ Var qn t loc transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc) = do - let (Info (xp, xtype, xargext), Info (yp, ytype)) = arg + let (Info (xp, xtype, xargext, xam), Info (yp, ytype)) = arg e' <- transformExp e desugarBinOpSection fname (Just e') Nothing t - (xp, xtype, xargext) - (yp, ytype, Nothing) + (xp, xtype, xargext, xam) + (yp, ytype, Nothing, mempty) (rettype, retext) loc transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do - let (Info (xp, xtype), Info (yp, ytype, yargext)) = arg + let (Info (xp, xtype), Info (yp, ytype, yargext, yam)) = arg e' <- transformExp e desugarBinOpSection fname Nothing (Just e') t - (xp, xtype, Nothing) - (yp, ytype, yargext) + (xp, xtype, Nothing, mempty) + (yp, ytype, yargext, yam) (rettype, []) loc transformExp (ProjectSection fields (Info t) loc) = do @@ -735,12 +735,12 @@ desugarBinOpSection :: Maybe Exp -> Maybe Exp -> StructType -> - (PName, ParamType, Maybe VName) -> - (PName, ParamType, Maybe VName) -> + (PName, ParamType, Maybe VName, AutoMap) -> + (PName, ParamType, Maybe VName, AutoMap) -> (ResRetType, [VName]) -> SrcLoc -> MonoM Exp -desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (RetType dims rettype, retext) loc = do +desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, yext, yam) (RetType dims rettype, retext) loc = do t' <- transformType t op <- transformFName loc fname $ toStruct t (v1, wrap_left, e1, p1) <- makeVarParam e_left =<< transformType xtype @@ -748,7 +748,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( let apply_left = mkApply op - [(xext, mempty, e1)] + [(xext, xam, e1)] (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) []) onDim (Var d typ _) | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc @@ -757,7 +757,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) ( rettype' = first onDim rettype body <- scoping (S.fromList [v1, v2]) $ - mkApply apply_left [(yext, mempty, e2)] + mkApply apply_left [(yext, yam, e2)] <$> transformAppRes (AppRes (toStruct rettype') retext) rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype' pure . wrap_left . wrap_right $ diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 01e4bddb21..527c734cbf 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -1047,11 +1047,11 @@ eval env (Lambda ps body _ (Info (RetType _ rt)) _) = evalFunction env [] ps body rt eval env (OpSection qv (Info t) _) = evalTermVar env qv $ toStruct t -eval env (OpSectionLeft qv _ e (Info (_, _, argext), _) (Info (RetType _ t), _) loc) = do +eval env (OpSectionLeft qv _ e (Info (_, _, argext, _), _) (Info (RetType _ t), _) loc) = do v <- evalArg env e argext f <- evalTermVar env qv (toStruct t) apply loc env f v -eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext)) (Info (RetType _ t)) loc) = do +eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext, _)) (Info (RetType _ t)) loc) = do y <- evalArg env e argext pure $ ValueFun $ \x -> do diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index bd2133f017..ef7afa4d30 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -848,7 +848,7 @@ data ExpBase f vn (QualName vn) (f StructType) (ExpBase f vn) - (f (PName, ParamType, Maybe VName), f (PName, ParamType)) + (f (PName, ParamType, Maybe VName, AutoMap), f (PName, ParamType)) (f ResRetType, f [VName]) SrcLoc | -- | @+2@; first type is operand, second is result. @@ -856,7 +856,7 @@ data ExpBase f vn (QualName vn) (f StructType) (ExpBase f vn) - (f (PName, ParamType), f (PName, ParamType, Maybe VName)) + (f (PName, ParamType), f (PName, ParamType, Maybe VName, AutoMap)) (f ResRetType) SrcLoc | -- | Field projection as a section: @(.x.y.z)@. diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index 798edae981..fc20935c24 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -184,25 +184,25 @@ instance ASTMappable (ExpBase Info VName) where <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> pure loc - astMap tv (OpSectionLeft name t arg (Info (pa, t1a, argext), Info (pb, t1b)) (ret, retext) loc) = + astMap tv (OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc) = OpSectionLeft <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> mapOnExp tv arg <*> ( (,) - <$> (Info <$> ((pa,,) <$> mapOnParamType tv t1a <*> pure argext)) + <$> (Info <$> ((pa,,,) <$> mapOnParamType tv t1a <*> pure argext <*> pure am)) <*> (Info <$> ((pb,) <$> mapOnParamType tv t1b)) ) <*> ((,) <$> traverse (mapOnResRetType tv) ret <*> pure retext) <*> pure loc - astMap tv (OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext)) t2 loc) = + astMap tv (OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc) = OpSectionRight <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> mapOnExp tv arg <*> ( (,) <$> (Info <$> ((pa,) <$> mapOnParamType tv t1a)) - <*> (Info <$> ((pb,,) <$> mapOnParamType tv t1b <*> pure argext)) + <*> (Info <$> ((pb,,,) <$> mapOnParamType tv t1b <*> pure argext <*> pure am)) ) <*> traverse (mapOnResRetType tv) t2 <*> pure loc diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 7c97574b15..d3b59381a8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -725,16 +725,16 @@ checkExp (OpSection op (Info op_t) loc) = do checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext, _) <- checkApply loc (Just op, 0) ftype e' + (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' case (ftype, rt) of - (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 rettype)) -> + (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 (RetType ds rt2))) -> pure $ OpSectionLeft op (Info ftype) e' - (Info (m1, toParam d1 t1, argext), Info (m2, toParam d2 t2)) - (Info rettype, Info retext) + (Info (m1, toParam d1 t1, argext, am), Info (m2, toParam d2 t2)) + (Info $ RetType ds $ arrayOfWithAliases (uniqueness rt2) (autoFrame am) rt2, Info retext) loc _ -> typeError loc mempty $ @@ -744,7 +744,7 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do - (t2', arrow', argext, _, _) <- + (t2', arrow', argext, _, am) <- checkApply loc (Just op, 1) @@ -757,8 +757,8 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do op (Info ftype) e' - (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext)) - (Info $ RetType dims2' ret') + (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext, am)) + (Info $ RetType dims2' $ arrayOfWithAliases (uniqueness ret') (autoFrame am) ret') loc _ -> error $ "OpSectionRight: impossible type\n" <> prettyString arrow' _ -> diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 52933dd1a3..335c532cea 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -599,8 +599,59 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) -checkApply loc _ ftype fframe arg = do +checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> [(Shape Size, Type)] -> TermM (StructType, [AutoMap]) +checkApply loc fname (fframe, ftype) args = do + ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args + rt' <- asStructType loc rt + pure (rt', argts) + where + -- pure (asStructType loc rt, argts) + + onArg (i, f_f, f_t) (argframe, argtype) = do + (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) + pure + ( (i + 1, autoFrame am, rt), + am + ) + +checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) +checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do + (a, b) <- split $ stripFrame fframe ftype + r <- newSVar loc "R" + m <- newSVar loc "M" + let unit_info = Info $ Scalar $ Prim Bool + r_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] m) unit_info mempty + lhs = arrayOf (toShape (SVar r) <> (toSComp <$> argframe)) argtype + rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a + ctAM r m + ctEq lhs rhs + pure + ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, + AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} + ) + where + stripFrame :: Shape Size -> Type -> Type + stripFrame frame (Array u ds t) = + let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) + in case mnew_shape of + Nothing -> Scalar t + Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t + stripFrame _ t = t + toSComp (Var (QualName [] x) _ _) = SVar x + toSComp _ = error "" + toShape = Shape . pure + split (Scalar (Arrow _ _ _ a (RetType _ b))) = + pure (a, b `setUniqueness` NoUniqueness) + split ftype' = do + a <- newType loc "arg" + b <- newTyVar loc "res" + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b + pure (a, tyVarType b) + +-- To be removed (probably) +checkApply_ :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) +checkApply_ loc _ ftype fframe arg = do (a, b) <- split $ stripFrame fframe ftype r <- newSVar loc "R" m <- newSVar loc "M" @@ -828,7 +879,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do onArg (i, f_t, f_f) (_, arg) = do arg' <- checkExp arg - (rt, am) <- checkApply loc (fname, i) f_t f_f arg' + (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' pure ( (i + 1, rt, autoFrame am), (Info (Nothing, am), arg') @@ -838,52 +889,55 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op e1' <- checkExp e1 e2' <- checkExp e2 - - (rt1, am1) <- checkApply loc (Just op, 0) (toType ftype) mempty e1' - (rt2, am2) <- checkApply loc (Just op, 1) rt1 (autoFrame am1) e2' - rt2' <- asStructType loc rt2 + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, toType ftype) + [(frameOf e1', toType $ typeOf e1'), (frameOf e2', toType $ typeOf e2')] + let [am1, am2] = ams pure $ AppExp (BinOp (op, oploc) (Info ftype) (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) - (Info (AppRes rt2' [])) + (Info (AppRes rt [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - void $ checkApply loc (Just op, 0) (toType optype) mempty e' + t2 <- newType loc "t" + t2' <- asStructType loc t2 let t1 = typeOf e' - t2 <- newType loc "t2" - rt <- newType loc "rt" - ctEq (toType optype) $ toType $ foldFunType [toParam Observe t1, t2] $ RetType [] $ rt `setUniqueness` Nonunique + f1 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(f1, toType t1), (mempty, t2)] pure $ OpSectionLeft op (Info optype) e' - ( Info (Unnamed, toParam Observe t1, Nothing), - Info (Unnamed, t2) + ( Info (Unnamed, toParam Observe t1, Nothing, ams !! 0), -- fix + Info (Unnamed, toParam Observe t2') ) - (Info (RetType [] rt), Info []) + (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) loc --- checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e t1 <- newType loc "t" + t1' <- asStructType loc t1 let t2 = typeOf e' - rt <- newType loc "rt" - ctEq (toType optype) $ toType $ foldFunType [t1, toParam Observe t2] $ RetType [] $ rt `setUniqueness` Nonunique + f2 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(mempty, t1), (f2, toType t2)] pure $ OpSectionRight op (Info optype) e' -- Dummy types. - ( Info (Unnamed, toParam Observe t1), - Info (Unnamed, toParam Observe t2, Nothing) + ( Info (Unnamed, toParam Observe t1'), + Info (Unnamed, toParam Observe t2, Nothing, ams !! 1) -- fix ) - (Info $ RetType [] rt) + (Info $ RetType [] (rt `setUniqueness` Nonunique)) loc -- checkExp (ProjectSection fields NoInfo loc) = do From 14c5544d5816b208453483c98d1d8ed6f91c33fe Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 19 Feb 2024 16:13:06 +0100 Subject: [PATCH 077/296] Flail at the constraint solver. Now explicitly returns variables to be generalised. --- .../Futhark/TypeChecker/Constraints.hs | 23 +++++++++++++------ src/Language/Futhark/TypeChecker/Terms.hs | 15 ++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 20 ++++++++++------ 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 2548c74008..ddaedd4025 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -117,19 +117,28 @@ substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = substTyVars m (Array u shape elemt) = arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt --- | A solution maps a type variable to its substitution. This substitution is complete, in the sense there are no right-hand sides that contain a type variable. +-- | A solution maps a type variable to its substitution. This +-- substitution is complete, in the sense there are no right-hand +-- sides that contain a type variable. type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -solution :: SolverState -> Solution +solution :: SolverState -> ([VName], Solution) solution s = - M.mapMaybe mkSubst $ - solverTyVars s + ( mapMaybe unconstrained $ M.toList $ solverTyVars s, + M.mapMaybe mkSubst $ solverTyVars s + ) where - mkSubst (TyVarSol _lvl t) = Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t - mkSubst (TyVarLink v') = mkSubst =<< M.lookup v' (solverTyVars s) + mkSubst (TyVarSol _lvl t) = + Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t + mkSubst (TyVarLink v') = + Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ + mkSubst =<< M.lookup v' (solverTyVars s) mkSubst (TyVarUnsol _ (TyVarPrim pts)) = Just $ Left pts mkSubst _ = Nothing + unconstrained (v, TyVarUnsol _ TyVarFree) = Just v + unconstrained _ = Nothing + newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) @@ -213,7 +222,7 @@ solveCt ct = Nothing -> bad Just eqs -> mapM_ solveCt' eqs -solve :: Constraints -> TyVars -> Either T.Text Solution +solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d3b59381a8..eacb5b6964 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1015,7 +1015,7 @@ checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp tysubsts $ do + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' when (hasBinding e'') $ @@ -1216,14 +1216,6 @@ localChecks = void . check e <$ case ty of Info (Scalar (Prim t)) -> errorBounds (inBoundsI (-x) t) (-x) t (loc1 <> loc2) _ -> error "Inferred type of int literal is not a number" - check e@(AppExp (BinOp (QualName [] v, _) _ (x, _) _ loc) _) - | baseName v == "==", - Array {} <- typeOf x, - baseTag v <= maxIntrinsicTag = do - warn loc $ - textwrap - "Comparing arrays with \"==\" is deprecated and will stop working in a future revision of the language." - recurse e check e = recurse e recurse = astMap identityMapper {mapOnExp = check} @@ -1597,9 +1589,10 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right tysubsts -> runTermTypeM checkExp tysubsts $ do + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained (tparams', params'', retdecl'', RetType dims rettype', body'') <- - checkBinding (fname, retdecl', tparams, params', body', loc) + checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) -- Since this is a top-level function, we also resolve overloaded -- types, using either defaults or complaining about ambiguities. diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 335c532cea..594e9af0ac 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -46,6 +46,7 @@ import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor +import Data.Bitraversable import Data.Char (isAscii) import Data.List qualified as L import Data.List.NonEmpty qualified as NE @@ -396,7 +397,7 @@ lookupVar loc qn@(QualName qs name) = do outer_env <- asks termOuterEnv asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do - argtype <- newType loc "t" + argtype <- newTypeOverloaded loc "t" anyPrimType pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts @@ -915,7 +916,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do op (Info optype) e' - ( Info (Unnamed, toParam Observe t1, Nothing, ams !! 0), -- fix + ( Info (Unnamed, toParam Observe t1, Nothing, head ams), -- fix Info (Unnamed, toParam Observe t2') ) (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) @@ -1149,7 +1150,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), + ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1183,7 +1184,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do Just (cts', tyvars', vns', counter') -> do modify $ \s -> s {termCounter = counter', termNameSource = vns'} - solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts' tyvars' + solution <- + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ + solve cts' tyvars' traceM $ unlines @@ -1193,17 +1196,20 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList) solution + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] pure (solution, params', retdecl', body') checkSingleExp :: ExpBase NoInfo VName -> - TypeM (Either T.Text (M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - solution <- traverse (M.traverseWithKey (doDefaults mempty)) $ solve cts tyvars + solution <- + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ + solve cts tyvars pure (solution, e') From 52beb8c7c22ac3fd69d33fa906d349db631de5c7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 19 Feb 2024 16:23:53 +0100 Subject: [PATCH 078/296] Print this too. --- src/Language/Futhark/TypeChecker.hs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 879a27afc4..9d07bf5b53 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -26,6 +26,7 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Ord import Data.Set qualified as S +import Debug.Trace import Futhark.FreshNames hiding (newName) import Futhark.Util.Pretty hiding (space) import Language.Futhark @@ -704,6 +705,9 @@ checkValBind vb = do _ -> pure () let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc + + traceM $ unlines ["Inferred:", prettyString vb'] + pure ( mempty { envVtable = From 8f4cb684e9f3ac8d53135cbf8bd4b316a8c1486a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 11:38:07 -0800 Subject: [PATCH 079/296] Add ambiguity checking. --- src/Language/Futhark/TypeChecker/Rank.hs | 89 +++++++++++++++--------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 23d295f8ee..bff7a08b84 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -92,7 +92,7 @@ binVar sv = do modify $ \s -> s { rankBinVars = M.insert sv bv $ rankBinVars s, - rankConstraints = rankConstraints s ++ [bin bv] + rankConstraints = rankConstraints s ++ [bin bv, var bv ~<=~ var sv] } pure bv Just bv -> pure bv @@ -141,36 +141,66 @@ mkLinearProg counter cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP +ambigCheckLinearProg :: LinearProg -> (Double, Map VName Int) -> LinearProg +ambigCheckLinearProg prog (opt, ranks) = + prog + { constraints = + constraints prog + ++ [ lsum (var <$> M.keys one_bins) + ~-~ lsum (var <$> M.keys zero_bins) + ~<=~ constant (fromIntegral $ length one_bins) + ~-~ constant 1, + objective prog ~==~ constant opt + ] + } + where + -- We really need to track which variables are binary in the LinearProg + is_bin_var = ("b_" `L.isPrefixOf`) . baseString + one_bins = M.filterWithKey (\k v -> is_bin_var k && v == 1) ranks + zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks + lsum = foldr (~+~) (constant 0) + rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) rankAnalysis use_glpk vns counter cs tyVars = do traceM $ unlines ["## rankAnalysis prog", prettyString prog] - rank_map <- - if use_glpk - then snd <$> (unsafePerformIO $ glpk prog) - else do - (_size, ranks) <- branchAndBound lp - pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) - let initEnv = - SubstEnv - { envTyVars = tyVars, - envRanks = rank_map - } - - initState = - SubstState - { substTyVars = mempty, - substNewVars = mempty, - substNameSource = vns, - substCounter = counter, - substNewCts = mempty - } - (cs', state') = - runSubstM initEnv initState $ - substRanks $ - filter (not . isCtAM) cs - pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') + -- rank_map <- + -- if use_glpk + -- then snd <$> (unsafePerformIO $ glpk prog) + -- else do + -- (_size, ranks) <- branchAndBound lp + -- pure $ (fromJust . (ranks V.!?)) <$> inv_var_map + (size, rank_map) <- unsafePerformIO $ glpk prog + case unsafePerformIO $ glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) of + Just (size', rank_map') -> do + traceM $ + unlines $ + "## rank map" + : map prettyString (M.toList rank_map) + ++ "## ambig rank map" + : map prettyString (M.toList rank_map') + error "ambiguous" + Nothing -> do + traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) + let initEnv = + SubstEnv + { envTyVars = tyVars, + envRanks = rank_map + } + + initState = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNameSource = vns, + substCounter = counter, + substNewCts = mempty + } + (cs', state') = + runSubstM initEnv initState $ + substRanks $ + filter (not . isCtAM) cs + pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -189,11 +219,6 @@ rankAnalysis use_glpk vns counter cs tyVars = do (lp, var_map) = linearProgToLP prog inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] - rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" - vname_to_pulp_var = M.mapWithKey (\k _ -> map rm_subscript $ show $ prettyName k) inv_var_map - pulp_var_to_vname = - M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList vname_to_pulp_var] - newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) From 6d91328e18839cfd68d5010071613ce86b5f6c48 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 11:41:23 -0800 Subject: [PATCH 080/296] Forgot the source. --- src/Language/Futhark/TypeChecker/Rank.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index bff7a08b84..d67032446e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -146,6 +146,7 @@ ambigCheckLinearProg prog (opt, ranks) = prog { constraints = constraints prog + -- https://yetanothermathprogrammingconsultant.blogspot.com/2011/10/integer-cuts.html ++ [ lsum (var <$> M.keys one_bins) ~-~ lsum (var <$> M.keys zero_bins) ~<=~ constant (fromIntegral $ length one_bins) From 2c19d786bc18216657b5aec61ff70b9d74a84966 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 19 Feb 2024 21:09:33 +0100 Subject: [PATCH 081/296] Add rep shapes here. --- src/Language/Futhark/TypeChecker/Terms.hs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index eacb5b6964..6b6bdcb16e 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -492,7 +492,7 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) rt e2' + (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) (arrayOf (autoFrame am1) rt) e2' pure $ AppExp @@ -987,9 +987,15 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do } pure (tp1, tp2'', argext, ext, am) -checkApply loc fname (Array _ _ t) arg = +checkApply loc fname (Array _ shape t) arg = do -- This implies the function is the result of an automap. - checkApply loc fname (Scalar t) arg + (t1, rt, argext, retext, am) <- checkApply loc fname (Scalar t) arg + let am' = + am + { autoRep = shape <> autoRep am, + autoFrame = shape <> autoFrame am + } + pure (t1, rt, argext, retext, am') checkApply _ _ _ _ = error "checkApply: impossible case" From 8dd13d4632b02dc3157ed9af9290ed3a04304d1d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 19:45:42 -0800 Subject: [PATCH 082/296] Make the design of `Rank.hs` less dumb. --- src/Language/Futhark/TypeChecker/Rank.hs | 174 ++++++++++----------- src/Language/Futhark/TypeChecker/Terms2.hs | 47 +++--- 2 files changed, 105 insertions(+), 116 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index d67032446e..eca628fdc1 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -6,20 +6,14 @@ import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M import Data.Maybe -import Data.Vector.Unboxed qualified as V import Debug.Trace -import Futhark.FreshNames qualified as FreshNames -import Futhark.MonadFreshNames hiding (newName) -import Futhark.Solve.BranchAndBound import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP -import Futhark.Solve.Simplex import Language.Futhark hiding (ScalarType) import Language.Futhark.TypeChecker.Constraints -import Language.Futhark.TypeChecker.Monad (mkTypeVarName) +import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe -import System.Process type LSum = LP.LSum VName Double @@ -61,7 +55,7 @@ instance Distribute Type where distribute = distributeOne where distributeOne (Array _ s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s $ tr) + Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s tr) distributeOne t = t instance Distribute Ct where @@ -83,9 +77,9 @@ incCounter = do put s {rankCounter = rankCounter s + 1} pure $ rankCounter s -binVar :: VName -> RankM (VName) +binVar :: VName -> RankM VName binVar sv = do - mbv <- (M.!? sv) <$> gets rankBinVars + mbv <- gets ((M.!? sv) . rankBinVars) case mbv of Nothing -> do bv <- VName ("b_" <> baseName sv) <$> incCounter @@ -112,7 +106,7 @@ addCt (CtAM r m) = do addConstraints $ oneIsZero (b_r, r) (b_m, m) addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () -addTyVarInfo tv (_, TyVarFree) = pure () +addTyVarInfo _ (_, TyVarFree) = pure () addTyVarInfo tv (_, TyVarPrim _) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarRecord _) = @@ -120,8 +114,8 @@ addTyVarInfo tv (_, TyVarRecord _) = addTyVarInfo tv (_, TyVarSum _) = addConstraint $ rank tv ~==~ constant 0 -mkLinearProg :: Int -> [Ct] -> TyVars -> LinearProg -mkLinearProg counter cs tyVars = +mkLinearProg :: [Ct] -> TyVars -> LinearProg +mkLinearProg cs tyVars = LP.LinearProg { optType = Minimize, objective = @@ -133,7 +127,7 @@ mkLinearProg counter cs tyVars = initState = RankState { rankBinVars = mempty, - rankCounter = counter, + rankCounter = 0, rankConstraints = mempty } buildLP = do @@ -161,50 +155,41 @@ ambigCheckLinearProg prog (opt, ranks) = zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks lsum = foldr (~+~) (constant 0) -rankAnalysis :: Bool -> VNameSource -> Int -> [Ct] -> TyVars -> Maybe ([Ct], TyVars, VNameSource, Int) -rankAnalysis _ vns counter [] tyVars = Just ([], tyVars, vns, counter) -rankAnalysis use_glpk vns counter cs tyVars = do - traceM $ unlines ["## rankAnalysis prog", prettyString prog] - -- rank_map <- - -- if use_glpk - -- then snd <$> (unsafePerformIO $ glpk prog) - -- else do - -- (_size, ranks) <- branchAndBound lp - -- pure $ (fromJust . (ranks V.!?)) <$> inv_var_map - (size, rank_map) <- unsafePerformIO $ glpk prog - case unsafePerformIO $ glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) of - Just (size', rank_map') -> do - traceM $ - unlines $ - "## rank map" - : map prettyString (M.toList rank_map) - ++ "## ambig rank map" - : map prettyString (M.toList rank_map') - error "ambiguous" - Nothing -> do - traceM $ unlines $ "## rank map" : map prettyString (M.toList rank_map) - let initEnv = - SubstEnv - { envTyVars = tyVars, - envRanks = rank_map - } - - initState = - SubstState - { substTyVars = mempty, - substNewVars = mempty, - substNameSource = vns, - substCounter = counter, - substNewCts = mempty - } - (cs', state') = - runSubstM initEnv initState $ - substRanks $ - filter (not . isCtAM) cs - pure (cs' <> substNewCts state', substTyVars state' <> tyVars, substNameSource state', substCounter state') +checkProg :: (MonadTypeChecker m, Located loc) => loc -> LinearProg -> m (Map VName Int) +checkProg loc prog = do + traceM $ + unlines + [ "## checkProg", + prettyString prog + ] + case run_glpk prog of + Nothing -> typeError loc mempty "Rank ILP cannot be solved." + Just sol@(_size, rank_map) -> + case check_ambig sol of + Nothing -> do + traceM $ + unlines $ + "## rank map" : map prettyString (M.toList rank_map) + pure rank_map + Just (_, rank_map') -> do + traceM $ + unlines $ + "## rank map" + : map prettyString (M.toList rank_map) + ++ "## ambig rank map" + : map prettyString (M.toList rank_map') + typeError loc mempty "Rank ILP is ambiguous." + where + run_glpk = unsafePerformIO . glpk + check_ambig (size, rank_map) = + run_glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) + +rankAnalysis :: (MonadTypeChecker m, Located loc) => loc -> [Ct] -> TyVars -> m ([Ct], TyVars) +rankAnalysis _ [] tyVars = pure ([], tyVars) +rankAnalysis loc cs tyVars = do + checkProg loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + >>= substRankInfo cs tyVars where - isCtAM (CtAM {}) = True - isCtAM _ = False splitFuncs ( CtEq (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) @@ -215,17 +200,43 @@ rankAnalysis use_glpk vns counter cs tyVars = do t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] - cs' = foldMap splitFuncs cs - prog = mkLinearProg counter cs' tyVars - (lp, var_map) = linearProgToLP prog - inv_var_map = M.fromListWith (error "oh no!") [(v, k) | (k, v) <- M.toList var_map] -newtype SubstM a = SubstM (StateT SubstState (Reader SubstEnv) a) - deriving (Functor, Applicative, Monad, MonadState SubstState, MonadReader SubstEnv) +substRankInfo :: (MonadTypeChecker m) => [Ct] -> TyVars -> Map VName Int -> m ([Ct], TyVars) +substRankInfo cs tyVars rankmap = do + (cs', new_cs, new_tyVars) <- + runSubstT tyVars rankmap $ + substRanks $ + filter (not . isCtAM) cs + pure (cs' <> new_cs, new_tyVars <> tyVars) + where + isCtAM (CtAM {}) = True + isCtAM _ = False + +runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) +runSubstT tyVars rankmap (SubstT m) = do + let env = + SubstEnv + { envTyVars = tyVars, + envRanks = rankmap + } -runSubstM :: SubstEnv -> SubstState -> SubstM a -> (a, SubstState) -runSubstM initEnv initState (SubstM m) = - runReader (runStateT m initState) initEnv + s = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNewCts = mempty + } + (a, s') <- runReaderT (runStateT m s) env + pure (a, substNewCts s', substTyVars s') + +newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) + deriving + ( Functor, + Applicative, + Monad, + MonadState SubstState, + MonadReader SubstEnv + ) data SubstEnv = SubstEnv { envTyVars :: TyVars, @@ -235,21 +246,15 @@ data SubstEnv = SubstEnv data SubstState = SubstState { substTyVars :: TyVars, substNewVars :: Map TyVar TyVar, - substNameSource :: VNameSource, - substCounter :: !Int, substNewCts :: [Ct] } -substIncCounter :: SubstM Int -substIncCounter = do - s <- get - put s {substCounter = substCounter s + 1} - pure $ substCounter s +instance MonadTrans SubstT where + lift = SubstT . lift . lift -newTyVar :: TyVar -> SubstM TyVar +newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar newTyVar t = do - i <- substIncCounter - t' <- newID $ mkTypeVarName (baseName t) i + t' <- lift $ newTypeName (baseName t) shape <- rankToShape t modify $ \s -> s @@ -262,22 +267,15 @@ newTyVar t = do ] } pure t' - where - newID x = do - s <- get - let (v', src') = FreshNames.newName (substNameSource s) $ VName x 0 - put $ s {substNameSource = src'} - pure v' -rankToShape :: VName -> SubstM (Shape SComp) +rankToShape :: (Monad m) => VName -> SubstT m (Shape SComp) rankToShape x = do rs <- asks envRanks pure $ Shape $ replicate (fromJust $ rs M.!? x) SDim -addRankInfo :: TyVar -> SubstM () +addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do rs <- asks envRanks - -- unless (fromMaybe (error $ prettyString t) (rs M.!? t) == 0) $ do unless (fromMaybe 0 (rs M.!? t) == 0) $ do new_vars <- gets substNewVars maybe new_var (const $ pure ()) $ new_vars M.!? t @@ -290,7 +288,7 @@ addRankInfo t = do modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} class SubstRanks a where - substRanks :: a -> SubstM a + substRanks :: (MonadTypeChecker m) => a -> SubstT m a instance (SubstRanks a) => SubstRanks [a] where substRanks = mapM substRanks @@ -298,11 +296,11 @@ instance (SubstRanks a) => SubstRanks [a] where instance SubstRanks (Shape SComp) where substRanks = foldM (\s d -> (s <>) <$> instDim d) mempty where - instDim (SDim) = pure $ Shape $ pure SDim + instDim SDim = pure $ Shape $ pure SDim instDim (SVar x) = rankToShape x instance SubstRanks (TypeBase SComp u) where - substRanks t@(Scalar (TypeVar u (QualName [] x) [])) = + substRanks t@(Scalar (TypeVar _ (QualName [] x) [])) = addRankInfo x >> pure t substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do ta' <- substRanks ta diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 594e9af0ac..e5a25df58c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1163,44 +1163,35 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do cts <- gets termConstraints - counter <- gets termCounter - tyvars <- gets termTyVars traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - vns <- gets termNameSource - - let use_glpk = True - traceM $ unlines [ "## cts:", unlines $ map prettyString cts ] - case rankAnalysis use_glpk vns counter cts tyvars of - Nothing -> error "" - Just (cts', tyvars', vns', counter') -> do - modify $ \s -> s {termCounter = counter', termNameSource = vns'} - - solution <- - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ - solve cts' tyvars' - - traceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - - pure (solution, params', retdecl', body') + (cts', tyvars') <- rankAnalysis loc cts tyvars + + solution <- + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ + solve cts' tyvars' + + traceM $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution + ] + + pure (solution, params', retdecl', body') checkSingleExp :: ExpBase NoInfo VName -> From af0b0bb45e9be1c75d4bb5f20295a688fdda0d35 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 20:50:35 -0800 Subject: [PATCH 083/296] Pass a list of possible solutions around. --- src/Language/Futhark/TypeChecker/Rank.hs | 59 +++++++-------- src/Language/Futhark/TypeChecker/Terms.hs | 88 ++++++++++++---------- src/Language/Futhark/TypeChecker/Terms2.hs | 42 ++++++----- 3 files changed, 99 insertions(+), 90 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index eca628fdc1..729f58848d 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -135,7 +135,7 @@ mkLinearProg cs tyVars = mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP -ambigCheckLinearProg :: LinearProg -> (Double, Map VName Int) -> LinearProg +ambigCheckLinearProg :: LinearProg -> (Int, Map VName Int) -> LinearProg ambigCheckLinearProg prog (opt, ranks) = prog { constraints = @@ -145,7 +145,7 @@ ambigCheckLinearProg prog (opt, ranks) = ~-~ lsum (var <$> M.keys zero_bins) ~<=~ constant (fromIntegral $ length one_bins) ~-~ constant 1, - objective prog ~==~ constant opt + objective prog ~==~ constant (fromIntegral opt) ] } where @@ -155,40 +155,39 @@ ambigCheckLinearProg prog (opt, ranks) = zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks lsum = foldr (~+~) (constant 0) -checkProg :: (MonadTypeChecker m, Located loc) => loc -> LinearProg -> m (Map VName Int) -checkProg loc prog = do +-- We should probably cap the iteration on this +enumerateRankSols :: LinearProg -> [Map VName Int] +enumerateRankSols prog = + takeSolns $ + iterate next_sol $ + (prog,) <$> run_glpk prog + where + run_glpk = unsafePerformIO . glpk + next_sol m = do + (prog', sol') <- m + let prog'' = ambigCheckLinearProg prog' sol' + sol'' <- run_glpk prog'' + pure (prog'', sol'') + takeSolns [] = [] + takeSolns (Nothing : _) = [] + takeSolns (Just (_, (_, r)) : xs) = r : takeSolns xs + +solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] +solveRankILP loc prog = do traceM $ unlines - [ "## checkProg", + [ "## solveRankILP", prettyString prog ] - case run_glpk prog of - Nothing -> typeError loc mempty "Rank ILP cannot be solved." - Just sol@(_size, rank_map) -> - case check_ambig sol of - Nothing -> do - traceM $ - unlines $ - "## rank map" : map prettyString (M.toList rank_map) - pure rank_map - Just (_, rank_map') -> do - traceM $ - unlines $ - "## rank map" - : map prettyString (M.toList rank_map) - ++ "## ambig rank map" - : map prettyString (M.toList rank_map') - typeError loc mempty "Rank ILP is ambiguous." - where - run_glpk = unsafePerformIO . glpk - check_ambig (size, rank_map) = - run_glpk $ ambigCheckLinearProg prog (fromIntegral size, rank_map) + case enumerateRankSols prog of + [] -> typeError loc mempty "Rank ILP cannot be solved." + rs -> pure rs -rankAnalysis :: (MonadTypeChecker m, Located loc) => loc -> [Ct] -> TyVars -> m ([Ct], TyVars) -rankAnalysis _ [] tyVars = pure ([], tyVars) +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> m [([Ct], TyVars)] +rankAnalysis _ [] tyVars = pure [([], tyVars)] rankAnalysis loc cs tyVars = do - checkProg loc (mkLinearProg (foldMap splitFuncs cs) tyVars) - >>= substRankInfo cs tyVars + solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + >>= mapM (substRankInfo cs tyVars) where splitFuncs ( CtEq diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 6b6bdcb16e..f80407b378 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1590,44 +1590,52 @@ checkFunDef :: Exp ) checkFunDef (fname, retdecl, tparams, params, body, loc) = do - (maybe_tysubsts, params', retdecl', body') <- + (maybe_tysubstss, params', retdecl', body') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - - case maybe_tysubsts of - Left err -> typeError loc mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do - let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained - (tparams', params'', retdecl'', RetType dims rettype', body'') <- - checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) - - -- Since this is a top-level function, we also resolve overloaded - -- types, using either defaults or complaining about ambiguities. - fixOverloadedTypes $ - typeVars rettype' <> foldMap (typeVars . patternType) params'' - - -- Then replace all inferred types in the body and parameters. - body''' <- updateTypes body'' - params''' <- updateTypes params'' - retdecl''' <- traverse updateTypes retdecl'' - rettype'' <- normTypeFully rettype' - - -- Check if the function body can actually be evaluated. - causalityCheck body''' - - -- Check for various problems. - mapM_ (mustBeIrrefutable . fmap toStruct) params'' - localChecks body''' - - let ((body'''', updated_ret), errors) = - Consumption.checkValDef - ( fname, - params''', - body''', - RetType dims rettype'', - retdecl''', - loc - ) - - mapM_ throwError errors - - pure (tparams', params''', retdecl''', updated_ret, body'''') + case maybe_tysubstss of + [] -> error "impossible" + [maybe_tysubsts] -> doChecks (maybe_tysubsts, params', retdecl', body') + _ -> typeError loc mempty "Rank ILP is ambiguous" + where + -- TODO: Print out the possibilities. (And also potentially eliminate + --- some of the possibilities to disambiguate). + + doChecks (maybe_tysubsts, params', retdecl', body') = + case maybe_tysubsts of + Left err -> typeError loc mempty $ pretty err + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained + (tparams', params'', retdecl'', RetType dims rettype', body'') <- + checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) + + -- Since this is a top-level function, we also resolve overloaded + -- types, using either defaults or complaining about ambiguities. + fixOverloadedTypes $ + typeVars rettype' <> foldMap (typeVars . patternType) params'' + + -- Then replace all inferred types in the body and parameters. + body''' <- updateTypes body'' + params''' <- updateTypes params'' + retdecl''' <- traverse updateTypes retdecl'' + rettype'' <- normTypeFully rettype' + + -- Check if the function body can actually be evaluated. + causalityCheck body''' + + -- Check for various problems. + mapM_ (mustBeIrrefutable . fmap toStruct) params'' + localChecks body''' + + let ((body'''', updated_ret), errors) = + Consumption.checkValDef + ( fname, + params''', + body''', + RetType dims rettype'', + retdecl''', + loc + ) + + mapM_ throwError errors + + pure (tparams', params''', retdecl''', updated_ret, body'''') diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e5a25df58c..467bde81bf 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1150,7 +1150,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), + ( [Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness))], [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1173,25 +1173,27 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map prettyString cts ] - (cts', tyvars') <- rankAnalysis loc cts tyvars - - solution <- - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ - solve cts' tyvars' - - traceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - - pure (solution, params', retdecl', body') + cts_tyvars' <- rankAnalysis loc cts tyvars + + solutions <- + forM cts_tyvars' $ + bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) + . uncurry solve + + forM (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> + traceM $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution + ] + + pure (solutions, params', retdecl', body') checkSingleExp :: ExpBase NoInfo VName -> From 012680d2cb46c96b33555ce5e3b073b643bff917 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 19 Feb 2024 20:58:55 -0800 Subject: [PATCH 084/296] Cap the number of solutions. --- src/Language/Futhark/TypeChecker/Rank.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 729f58848d..247f0555f7 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -155,12 +155,12 @@ ambigCheckLinearProg prog (opt, ranks) = zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks lsum = foldr (~+~) (constant 0) --- We should probably cap the iteration on this enumerateRankSols :: LinearProg -> [Map VName Int] enumerateRankSols prog = - takeSolns $ - iterate next_sol $ - (prog,) <$> run_glpk prog + take 5 $ + takeSolns $ + iterate next_sol $ + (prog,) <$> run_glpk prog where run_glpk = unsafePerformIO . glpk next_sol m = do From 9ab28c07195c8e1ebc275a8769f1e172e6d72e4d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 05:37:42 -0800 Subject: [PATCH 085/296] Make big `M` actually big(ish). --- src/Futhark/Solve/LP.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 7623033e7c..ca2fbe73f0 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -201,7 +201,7 @@ linearProgToPulp prog = rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" bigM :: (Num a) => a -bigM = 10 ^ 3 +bigM = 10 ^ 6 oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = From 7616bcce1174243c3560e27f439394cbddf9220c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 05:42:00 -0800 Subject: [PATCH 086/296] AUTOMAP for short-circuiting operators. --- src/Futhark/Internalise/Exps.hs | 44 +++++++++++---------------------- 1 file changed, 14 insertions(+), 30 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 2e3cabd4f5..f8d20ee005 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -363,28 +363,26 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = () -- Short-circuiting operators are magical. | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "&&", - [(x, _), (y, _)] <- args -> - internaliseExp desc $ - E.AppExp - (E.If x y (E.Literal (E.BoolValue False) mempty) mempty) - (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) + baseString (qualLeaf qfname) == "&&" -> + withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do + letValExp' desc + =<< eIf + (addStms x_stms >> pure (BasicOp $ SubExp x)) + (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) + (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue False]) | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "||", - [(x, _), (y, _)] <- args -> - internaliseExp desc $ - E.AppExp - (E.If x (E.Literal (E.BoolValue True) mempty) y mempty) - (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) + baseString (qualLeaf qfname) == "||" -> + withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do + letValExp' desc + =<< eIf + (addStms x_stms >> pure (BasicOp $ SubExp x)) + (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue True]) + (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) -- Overloaded and intrinsic functions never take array -- arguments (except equality, but those cannot be -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - -- let prepareArg (arg, _) = - -- (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg - -- internalise =<< mapM prepareArg args - -- withAutoMap_ ams arg_desc res_t args $ \args' -> do let prepareArg (arg, _, am) arg' = (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') @@ -392,26 +390,12 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - -- Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do - -- let tag ses = [(se, I.Observe) | se <- ses] - -- args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) - -- let args'' = concatMap tag args' - -- letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) - -- Just (rettype, _) <- M.lookup fname I.builtInFunctions -> withAutoMap_ ams arg_desc res_t args $ \args' -> do let tag ses = [(se, I.Observe) | se <- ses] let args'' = concatMap tag args' letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - traceM $ - unlines - [ "## qfname", - prettyString qfname - ] - -- args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) - -- funcall desc qfname args' loc - withAutoMap_ ams arg_desc res_t args $ \args' -> funcall desc qfname (concat args') loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = From 870f97f4acafc16212fc1b2974b5c3c36bc1b4fa Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 06:08:56 -0800 Subject: [PATCH 087/296] Use `Int`s instead of `Double`s and print out rank maps. --- src/Language/Futhark/TypeChecker/Rank.hs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 247f0555f7..5989c27593 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -15,11 +15,11 @@ import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe -type LSum = LP.LSum VName Double +type LSum = LP.LSum VName Int -type Constraint = LP.Constraint VName Double +type Constraint = LP.Constraint VName Int -type LinearProg = LP.LinearProg VName Double +type LinearProg = LP.LinearProg VName Int type ScalarType = ScalarTypeBase SComp NoUniqueness @@ -181,7 +181,14 @@ solveRankILP loc prog = do ] case enumerateRankSols prog of [] -> typeError loc mempty "Rank ILP cannot be solved." - rs -> pure rs + rs -> do + traceM "## rank maps" + forM (zip [0 :: Int ..] rs) $ \(i, r) -> + traceM $ + unlines $ + "\n## rank map " <> prettyString i + : map prettyString (M.toList r) + pure rs rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> m [([Ct], TyVars)] rankAnalysis _ [] tyVars = pure [([], tyVars)] From 87da34a300b0acf49e951a75f7215d72f56a7b35 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 06:14:59 -0800 Subject: [PATCH 088/296] Apparently powers of 2 are better. 10^6 also somehow gives incorrect results. --- src/Futhark/Solve/LP.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index ca2fbe73f0..044f6efe63 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -201,7 +201,7 @@ linearProgToPulp prog = rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" bigM :: (Num a) => a -bigM = 10 ^ 6 +bigM = 2 ^ 10 oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = From 043ff227d4692f18b09465763c12e21521a68ecb Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 20 Feb 2024 08:11:27 -0800 Subject: [PATCH 089/296] Support auto replicates in internalization. --- src/Futhark/Internalise/Exps.hs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index f8d20ee005..6f25174a91 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -938,8 +938,24 @@ withAutoMap ams arg_desc res_t args_e innerM = do ts | otherwise = pure $ Right $ zip ses ts + internaliseShape = + fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims + + addReplicates = + zipWithM + ( \am arg -> do + rep_shape <- + internaliseShape $ + autoRep am `E.shapePrefix` autoFrame am + if I.shapeRank rep_shape > 0 + then concat <$> mapM (letValExp' "autoRep" . BasicOp . Replicate rep_shape) arg + else pure arg + ) + expand args stms argts ams' level - | level <= 0 = innerM $ zip args stms + | level <= 0 = do + args' <- addReplicates ams' args + innerM $ zip args' stms | otherwise = do let ds' = map autoMapRank ams' arg_params <- mapM (mkLambdaParams level) $ zip4 args argts stms ds' From 674bf01aa19e9727fea49b2adb51046f8b41770c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 21 Feb 2024 15:26:54 +0100 Subject: [PATCH 090/296] Fiddle with Ident type checking. --- src/Language/Futhark/TypeChecker/Terms.hs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index f80407b378..45f6c820fb 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -343,6 +343,10 @@ unscopeType :: unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp +checkIdent :: Ident StructType -> TermTypeM (Ident StructType) +checkIdent (Ident v t loc) = + Ident v <$> traverse (replaceTyVars loc) t <*> pure loc + checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc @@ -601,18 +605,20 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, _, e) body loc) _ ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do + dest' <- checkIdent dest + src' <- checkIdent src slice' <- checkSlice slice - (t, _) <- newArrayType (mkUsage src "type of source array") "src" $ sliceDims slice' - unify (mkUsage loc "type of target array") t $ unInfo $ identType src + (t, _) <- newArrayType (mkUsage src' "type of source array") "src" $ sliceDims slice' + unify (mkUsage loc "type of target array") t $ unInfo $ identType src' (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t ve' <- unifies "type of target array" elemt =<< checkExp ve - bindingIdent dest $ do + bindingIdent dest' $ do body' <- checkExp body - (body_t, ext) <- unscopeType loc [identName dest] =<< expTypeFully body' - pure $ AppExp (LetWith dest src slice' ve' body' loc) (Info $ AppRes body_t ext) + (body_t, ext) <- unscopeType loc [identName dest'] =<< expTypeFully body' + pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t ext) checkExp (Update src slice ve loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' src) "src" $ sliceDims slice' From a10bd3d30615b30fa0f31f8b77f3ba26e18d5107 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 21 Feb 2024 19:52:30 +0100 Subject: [PATCH 091/296] Fix type checking of LetWith. --- src/Language/Futhark/TypeChecker/Terms.hs | 9 +++------ src/Language/Futhark/TypeChecker/Terms/Monad.hs | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 45f6c820fb..9c8ff1ddcb 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -343,10 +343,6 @@ unscopeType :: unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp -checkIdent :: Ident StructType -> TermTypeM (Ident StructType) -checkIdent (Ident v t loc) = - Ident v <$> traverse (replaceTyVars loc) t <*> pure loc - checkExp :: Exp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc @@ -605,8 +601,9 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, _, e) body loc) _ ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do - dest' <- checkIdent dest - src' <- checkIdent src + src_t <- lookupVar loc (qualName (identName src)) (unInfo (identType src)) + let src' = src {identType = Info src_t} + dest' = dest {identType = Info src_t} slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage src' "type of source array") "src" $ sliceDims slice' unify (mkUsage loc "type of target array") t $ unInfo $ identType src' diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 8bdbb81daf..094f4f8b62 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -358,7 +358,7 @@ replaceTyVars loc orig_t = do f (Scalar (TypeVar u (QualName [] v) [])) | Just t <- M.lookup v tyvars = - fst <$> allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" (second (const u) t) + fst <$> allDimsFreshInType (mkUsage loc "replaceTyVars") Nonrigid "dv" (second (const u) t) | otherwise = pure $ Scalar (TypeVar u (QualName [] v) []) f (Scalar (TypeVar u qn targs)) = From 2bb4085cf4ea602ebce6b9d739c64839a68f3682 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 21 Feb 2024 21:43:56 +0100 Subject: [PATCH 092/296] OptionPricing now type checks. --- src/Language/Futhark/TypeChecker/Terms.hs | 4 +++- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 12 ++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 9c8ff1ddcb..f4fc11ebdd 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1333,7 +1333,8 @@ checkBinding (fname, maybe_retdecl, tparams, params, body, loc) = verifyFunctionParams (Just fname) params'' (tparams', params''', rettype') <- - letGeneralise (baseName fname) loc tparams params'' =<< unscopeUnknown rettype + letGeneralise (baseName fname) loc tparams params'' + =<< unscopeUnknown rettype when ( null params @@ -1449,6 +1450,7 @@ closeOverTypes defname defloc tparams paramts ret substs = do case M.lookup v substs of Just (_, UnknownSize {}) -> Just v _ -> Nothing + pure ( tparams ++ more_tparams, injectExt (nubOrd $ retext ++ mapMaybe mkExt (S.toList $ fvVars $ freeInType ret)) ret diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index b1a2f59a8d..dbd1d019f0 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -127,7 +127,7 @@ checkPat' _ (Id name (Info t) loc) NoneInferred = do pure $ Id name (Info t') loc checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do t <- replaceTyVars loc t1 - unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) + unify (mkUsage loc "id") (toStruct t) (toStruct t2) pure $ Id name (Info t) loc checkPat' _ (Wildcard (Info t) loc) NoneInferred = do t' <- replaceTyVars loc t diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 467bde81bf..6c0a55ab60 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -488,16 +488,20 @@ checkPat' (RecordPat fs loc) NoneInferred = checkPat' (PatAscription p t loc) maybe_outer_t = do (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp t + -- Uniqueness kung fu to make the Monoid(mempty) instance give what + -- we expect. We should perhaps stop being so implicit. + st' <- asStructType loc $ toType $ resToParam st + case maybe_outer_t of Ascribed outer_t -> do - ctEq (toType st) (toType outer_t) + ctEq (toType st') (toType outer_t) PatAscription - <$> checkPat' p (Ascribed (resToParam st)) + <$> checkPat' p (Ascribed st') <*> pure t' <*> pure loc NoneInferred -> PatAscription - <$> checkPat' p (Ascribed (resToParam st)) + <$> checkPat' p (Ascribed st') <*> pure t' <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do @@ -1180,7 +1184,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) . uncurry solve - forM (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> + forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> traceM $ unlines [ "## constraints:", From 4af653a645d18612fd89867933ab23fe0592da15 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 21 Feb 2024 23:34:23 -0800 Subject: [PATCH 093/296] - Introduce rank representation for `AutoMap` annotations to carry rank information across type checking phases. - Fix AUTOMAP in `Terms.hs`. --- src/Language/Futhark/Syntax.hs | 27 ++++++-- src/Language/Futhark/TypeChecker/Rank.hs | 42 +++++++++++-- src/Language/Futhark/TypeChecker/Terms.hs | 72 +++++++++++----------- src/Language/Futhark/TypeChecker/Terms2.hs | 6 +- 4 files changed, 99 insertions(+), 48 deletions(-) diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index ef7afa4d30..fae5741c8e 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -24,6 +24,9 @@ module Language.Futhark.Syntax shapeRank, stripDims, AutoMap (..), + autoRepRank, + autoMapRank, + autoFrameRank, TypeBase (..), TypeArg (..), SizeExp (..), @@ -261,13 +264,27 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap = AutoMap - { autoRep :: Shape Size, - autoMap :: Shape Size, - autoFrame :: Shape Size - } +data AutoMap + = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size + } + | AutoMapRank Int Int Int deriving (Eq, Show, Ord) +autoRepRank :: AutoMap -> Int +autoRepRank (AutoMapRank r _ _) = r +autoRepRank _ = 0 + +autoMapRank :: AutoMap -> Int +autoMapRank (AutoMapRank _ m _) = m +autoMapRank _ = 0 + +autoFrameRank :: AutoMap -> Int +autoFrameRank (AutoMapRank _ _ f) = f +autoFrameRank _ = 0 + instance Semigroup AutoMap where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 5989c27593..554eba77dd 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -2,6 +2,8 @@ module Language.Futhark.TypeChecker.Rank (rankAnalysis) where import Control.Monad.Reader import Control.Monad.State +import Data.Bifunctor +import Data.Functor.Identity import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M @@ -11,6 +13,7 @@ import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Language.Futhark hiding (ScalarType) +import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe @@ -190,11 +193,13 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> m [([Ct], TyVars)] -rankAnalysis _ [] tyVars = pure [([], tyVars)] -rankAnalysis loc cs tyVars = do - solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) - >>= mapM (substRankInfo cs tyVars) +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] +rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] +rankAnalysis loc cs tyVars body = do + rank_maps <- solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps + let bodys = map (flip updAM body) rank_maps + pure $ zip cts_tyvars' bodys where splitFuncs ( CtEq @@ -321,3 +326,30 @@ instance SubstRanks (TypeBase SComp u) where instance SubstRanks Ct where substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" + +updAM :: Map VName Int -> Exp -> Exp +updAM rank_map e = + case e of + AppExp (Apply f args loc) res -> + let f' = updAM rank_map f + args' = + fmap + ( bimap + (fmap $ bimap id upd) + (updAM rank_map) + ) + args + in AppExp (Apply f' args' loc) res + AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> + AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res + _ -> runIdentity $ astMap m e + where + dimToRank (Var (QualName [] x) _ _) = rank_map M.! x + dimToRank e = error $ prettyString e + shapeToRank = sum . fmap dimToRank + upd (AutoMap r m f) = + AutoMapRank (shapeToRank r) (shapeToRank m) (shapeToRank f) + m = + identityMapper + { mapOnExp = pure . updAM rank_map + } diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index f4fc11ebdd..8c48591696 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -466,33 +466,33 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp (AppExp (Apply fe args loc) _) = do +checkExp e@(AppExp (Apply fe args loc) _) = do fe' <- checkExp fe + let ams = fmap (snd . unInfo . fst) args args' <- mapM (checkExp . snd) args t <- expType fe' let fname = case fe' of Var v _ _ -> Just v _ -> Nothing - ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' + ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) (NE.zip args' ams) pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts where - onArg fname (i, all_exts, t) arg' = do - (_, rt, argext, exts, am) <- checkApply loc (fname, i) t arg' + onArg fname (i, all_exts, t) (arg', am) = do + (_, rt, argext, exts, am') <- checkApply loc (fname, i) t arg' am pure - ( (i + 1, all_exts <> exts, arrayOf (autoFrame am) rt), - (Info (argext, am), arg') + ( (i + 1, all_exts <> exts, rt), + (Info (argext, am'), arg') ) -checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do +checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, Info (_, xam)) (e2, Info (_, yam)) loc) _) = do ftype <- lookupVar oploc op op_t e1' <- checkExp e1 e2' <- checkExp e2 - -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' - (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) (arrayOf (autoFrame am1) rt) e2' + (_, rt, p1_ext, _, am1) <- checkApply loc (Just op, 0) ftype e1' xam + (_, rt', p2_ext, retext, am2) <- checkApply loc (Just op, 1) rt e2' yam pure $ AppExp @@ -503,7 +503,7 @@ checkExp (AppExp (BinOp (op, oploc) (Info op_t) (e1, _) (e2, _) loc) _) = do (e2', Info (p2_ext, am2)) loc ) - (Info (AppRes (arrayOf (autoFrame am2) rt') retext)) + (Info (AppRes rt' retext)) checkExp (Project k e _ loc) = do e' <- checkExp e t <- expType e' @@ -725,10 +725,10 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do checkExp (OpSection op (Info op_t) loc) = do ftype <- lookupVar loc op op_t pure $ OpSection op (Info ftype) loc -checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do +checkExp (OpSectionLeft op (Info op_t) e (Info (_, _, _, am), _) _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' + (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' am case (ftype, rt) of (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 (RetType ds rt2))) -> pure $ @@ -742,7 +742,7 @@ checkExp (OpSectionLeft op (Info op_t) e _ _ loc) = do _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype -checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do +checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e case ftype of @@ -753,6 +753,7 @@ checkExp (OpSectionRight op (Info op_t) e _ _ loc) = do (Just op, 1) (Scalar $ Arrow mempty m2 d2 t2 $ RetType [] $ Scalar $ Arrow Nonunique m1 d1 t1 $ RetType dims2 ret) e' + am case arrow' of Scalar (Arrow _ _ _ t1' (RetType dims2' ret')) -> pure $ @@ -930,19 +931,25 @@ stripToMatch paramt (Array _ (Shape (d : ds)) argt) = first (Shape [d] <>) $ stripToMatch paramt $ arrayOf (Shape ds) (Scalar argt) stripToMatch _ argt = (mempty, argt) +splitArrayAt :: Int -> StructType -> (Shape Size, StructType) +splitArrayAt x t = + (Shape $ take x $ shapeDims $ arrayShape t, stripArray x t) + checkApply :: SrcLoc -> ApplyOp -> StructType -> Exp -> + AutoMap -> TermTypeM (StructType, StructType, Maybe VName, [VName], AutoMap) -checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do +checkApply loc (fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do - (am_map_shape, argtype_automap) <- - stripToMatch <$> normTypeFully tp1 <*> normTypeFully argtype + (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype + (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 + let (am_frame_shape, argtype_automap) = splitArrayAt (autoFrameRank am) argtype_with_frame - unify (mkUsage argexp "use as function argument") tp1 argtype_automap + unify (mkUsage argexp "use as function argument") tp1_with_frame argtype_with_frame -- Perform substitutions of instantiated variables in the types. (tp2', ext) <- instantiateDimsInReturnType loc fname =<< normTypeFully tp2 @@ -986,21 +993,16 @@ checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do AutoMap { autoMap = am_map_shape, autoRep = mempty, - autoFrame = am_map_shape + autoFrame = am_map_shape <> am_frame_shape } - pure (tp1, tp2'', argext, ext, am) -checkApply loc fname (Array _ shape t) arg = do - -- This implies the function is the result of an automap. - (t1, rt, argext, retext, am) <- checkApply loc fname (Scalar t) arg - let am' = - am - { autoRep = shape <> autoRep am, - autoFrame = shape <> autoFrame am - } - pure (t1, rt, argext, retext, am') -checkApply _ _ _ _ = - error "checkApply: impossible case" + pure (tp1, distributeFrame (autoMap am) tp2'', argext, ext, am) + where + distributeFrame frame (Scalar (Arrow u p d a (RetType ds b))) = + Scalar $ Arrow u p d (arrayOf frame a) (RetType ds (arrayOfWithAliases (uniqueness b) frame b)) + distributeFrame frame t = arrayOf frame t +checkApply _ _ _ _ _ = + error "checkApply: array" -- | Type-check a single expression in isolation. This expression may -- turn out to be polymorphic, in which case the list of type @@ -1595,11 +1597,11 @@ checkFunDef :: Exp ) checkFunDef (fname, retdecl, tparams, params, body, loc) = do - (maybe_tysubstss, params', retdecl', body') <- + (maybe_tysubstss, params', retdecl', bodys') <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - case maybe_tysubstss of - [] -> error "impossible" - [maybe_tysubsts] -> doChecks (maybe_tysubsts, params', retdecl', body') + case (maybe_tysubstss, bodys') of + ([], _) -> error "impossible" + ([maybe_tysubsts], [body']) -> doChecks (maybe_tysubsts, params', retdecl', body') _ -> typeError loc mempty "Rank ILP is ambiguous" where -- TODO: Print out the possibilities. (And also potentially eliminate diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6c0a55ab60..6b1f9cb763 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1157,7 +1157,7 @@ checkValDef :: ( [Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness))], [Pat ParamType], Maybe (TypeExp Exp VName), - Exp + [Exp] ) checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bindParams tparams params $ \params' -> do @@ -1177,7 +1177,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map prettyString cts ] - cts_tyvars' <- rankAnalysis loc cts tyvars + (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' solutions <- forM cts_tyvars' $ @@ -1197,7 +1197,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] - pure (solutions, params', retdecl', body') + pure (solutions, params', retdecl', bodys') checkSingleExp :: ExpBase NoInfo VName -> From c336b16008f1fdba4ae61dba508ad0301db4c0f5 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 09:13:16 +0100 Subject: [PATCH 094/296] Do not rewrite automapped short-circuiting ops. --- src/Futhark/Internalise/FullNormalise.hs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 6ee354ea4f..509d41c7ca 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -299,16 +299,22 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do body' <- transformBody body nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do + -- Rewrite short-circuiting boolean operators on scalars to explicit + -- if-then-else. expr' <- case (isOr, isAnd) of - (True, _) -> do - el' <- naming "or_lhs" $ getOrdering True el - er' <- naming "or_rhs" $ transformBody er - pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) - (_, True) -> do - el' <- naming "and_lhs" $ getOrdering True el - er' <- naming "and_rhs" $ transformBody er - pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) - (False, False) -> do + (True, _) + | elam == mempty, + eram == mempty -> do + el' <- naming "or_lhs" $ getOrdering True el + er' <- naming "or_rhs" $ transformBody er + pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) + (_, True) + | elam == mempty, + eram == mempty -> do + el' <- naming "and_lhs" $ getOrdering True el + er' <- naming "and_rhs" $ transformBody er + pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) + _ -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT From ff353e57ba5924dffd1d6afe2451a6063db95dcf Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 09:15:57 +0100 Subject: [PATCH 095/296] We do not need this. --- src/Futhark/Internalise/Exps.hs | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 6f25174a91..6eca179e3c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -361,23 +361,6 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- Some functions are magical (overloaded) and we handle that here. case () of () - -- Short-circuiting operators are magical. - | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "&&" -> - withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do - letValExp' desc - =<< eIf - (addStms x_stms >> pure (BasicOp $ SubExp x)) - (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) - (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue False]) - | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - baseString (qualLeaf qfname) == "||" -> - withAutoMap ams arg_desc res_t args $ \[([x], x_stms), ([y], y_stms)] -> do - letValExp' desc - =<< eIf - (addStms x_stms >> pure (BasicOp $ SubExp x)) - (eBody [pure $ BasicOp $ SubExp $ Constant $ I.BoolValue True]) - (addStms y_stms >> eBody [pure $ BasicOp $ SubExp y]) -- Overloaded and intrinsic functions never take array -- arguments (except equality, but those cannot be -- existential), so we can safely ignore the existential From 986856d32bf2726c66b42a731d46bcd044155139 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 09:40:05 +0100 Subject: [PATCH 096/296] Try to handle logical operators. --- src/Futhark/Internalise/Exps.hs | 5 ++++- src/Futhark/Internalise/FullNormalise.hs | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 6eca179e3c..c17db733ba 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1680,12 +1680,15 @@ isOverloadedFunction qname desc loc = do handle name | Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = Just $ \[(x_t, [x']), (y_t, [y'])] -> - case (x_t, y_t) of + case (arrayElem x_t, arrayElem y_t) of (E.Scalar (E.Prim t1), E.Scalar (E.Prim t2)) -> internaliseBinOp loc desc bop x' y' t1 t2 _ -> error "Futhark.Internalise.internaliseExp: non-primitive type in BinOp." handle _ = Nothing + arrayElem (E.Array _ _ t) = E.Scalar t + arrayElem t = t + -- | Handle intrinsic functions. These are only allowed to be called -- in the prelude, and their internalisation may involve inspecting -- the AST. diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 509d41c7ca..17841c9f53 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -300,7 +300,8 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do -- Rewrite short-circuiting boolean operators on scalars to explicit - -- if-then-else. + -- if-then-else. Automapped cases are turned into applications of + -- intrinsic functions. expr' <- case (isOr, isAnd) of (True, _) | elam == mempty, @@ -308,18 +309,30 @@ getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info el' <- naming "or_lhs" $ getOrdering True el er' <- naming "or_rhs" $ transformBody er pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) + | otherwise -> do + el' <- naming "or_lhs" $ getOrdering False el + er' <- naming "or_rhs" $ getOrdering False er + pure $ mkApply orop [(elp, elam, el'), (erp, eram, er')] resT (_, True) | elam == mempty, eram == mempty -> do el' <- naming "and_lhs" $ getOrdering True el er' <- naming "and_rhs" $ transformBody er pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) + | otherwise -> do + el' <- naming "and_lhs" $ getOrdering False el + er' <- naming "and_rhs" $ getOrdering False er + pure $ mkApply andop [(elp, elam, el'), (erp, eram, er')] resT _ -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT nameExp final expr' where + bool = Scalar $ Prim Bool + opt = foldFunType [bool, bool] $ RetType [] bool + andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty + orop = Var (qualName (intrinsicVar "||")) (Info opt) mempty isOr = baseName (qualLeaf op) == "||" isAnd = baseName (qualLeaf op) == "&&" getOrdering final (AppExp (LetWith (Ident dest dty dloc) (Ident src sty sloc) slice e body loc) _) = do From 263ec737c8dbbac6a3d88e227592926d6d6a88f9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 00:47:16 -0800 Subject: [PATCH 097/296] Remove `autoMapRank`. --- src/Language/Futhark/Syntax.hs | 21 ++++++++------------- src/Language/Futhark/TypeChecker/Rank.hs | 7 ++++--- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index fae5741c8e..b39a82cd0b 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -264,26 +264,21 @@ stripDims i (Shape l) | i < length l = Just $ Shape $ drop i l | otherwise = Nothing -data AutoMap - = AutoMap - { autoRep :: Shape Size, - autoMap :: Shape Size, - autoFrame :: Shape Size - } - | AutoMapRank Int Int Int +data AutoMap = AutoMap + { autoRep :: Shape Size, + autoMap :: Shape Size, + autoFrame :: Shape Size + } deriving (Eq, Show, Ord) autoRepRank :: AutoMap -> Int -autoRepRank (AutoMapRank r _ _) = r -autoRepRank _ = 0 +autoRepRank = shapeRank . autoRep autoMapRank :: AutoMap -> Int -autoMapRank (AutoMapRank _ m _) = m -autoMapRank _ = 0 +autoMapRank = shapeRank . autoMap autoFrameRank :: AutoMap -> Int -autoFrameRank (AutoMapRank _ _ f) = f -autoFrameRank _ = 0 +autoFrameRank = shapeRank . autoFrame instance Semigroup AutoMap where (AutoMap r1 m1 f1) <> (AutoMap r2 m2 f2) = AutoMap (r1 <> r2) (m1 <> m2) (f1 <> f2) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 554eba77dd..1eef479dcf 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -344,11 +344,12 @@ updAM rank_map e = AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res _ -> runIdentity $ astMap m e where - dimToRank (Var (QualName [] x) _ _) = rank_map M.! x + dimToRank (Var (QualName [] x) _ _) = + replicate (rank_map M.! x) (TupLit mempty mempty) dimToRank e = error $ prettyString e - shapeToRank = sum . fmap dimToRank + shapeToRank = Shape . foldMap dimToRank upd (AutoMap r m f) = - AutoMapRank (shapeToRank r) (shapeToRank m) (shapeToRank f) + AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) m = identityMapper { mapOnExp = pure . updAM rank_map From 285604e1a1dca422fc7f20fc25e682a9a1c4ba4c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 14:28:23 +0100 Subject: [PATCH 098/296] Another hash. --- src/Language/Futhark/TypeChecker.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 6ce5977c9e..d6b5baf8fc 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -707,7 +707,7 @@ checkValBind vb = do let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - traceM $ unlines ["Inferred:", prettyString vb'] + traceM $ unlines ["# Inferred:", prettyString vb'] pure ( mempty From 2334dfd11e1209a4cc2d409a494a9c801d69f087 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 22 Feb 2024 14:28:56 +0100 Subject: [PATCH 099/296] Make uniqueness explicit. --- src/Language/Futhark/TypeChecker/Terms2.hs | 122 +++++++++++---------- 1 file changed, 65 insertions(+), 57 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6b1f9cb763..d9e5930b71 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -216,8 +216,8 @@ incCounter = do put s {termCounter = termCounter s + 1} pure $ termCounter s -tyVarType :: (Monoid u) => TyVar -> TypeBase dim u -tyVarType v = Scalar $ TypeVar mempty (qualName v) [] +tyVarType :: u -> TyVar -> TypeBase dim u +tyVarType u v = Scalar $ TypeVar u (qualName v) [] newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar newTyVarWith _loc desc info = do @@ -230,38 +230,46 @@ newTyVarWith _loc desc info = do newTyVar :: (Located loc) => loc -> Name -> TermM TyVar newTyVar loc desc = newTyVarWith loc desc TyVarFree -newType :: (Located loc, Monoid u) => loc -> Name -> TermM (TypeBase dim u) -newType loc desc = tyVarType <$> newTyVar loc desc +newType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) +newType loc desc u = tyVarType u <$> newTyVar loc desc -newTypeWithField :: (Monoid u) => SrcLoc -> Name -> Name -> Type -> TermM (TypeBase d u) +newTypeWithField :: SrcLoc -> Name -> Name -> Type -> TermM Type newTypeWithField loc desc k t = - tyVarType <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) + tyVarType NoUniqueness <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) -newTypeWithConstr :: (Monoid u) => SrcLoc -> Name -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) -newTypeWithConstr loc desc k ts = - tyVarType <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') +newTypeWithConstr :: SrcLoc -> Name -> u -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) +newTypeWithConstr loc desc u k ts = + tyVarType u <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') where ts' = map (`setUniqueness` NoUniqueness) ts -newTypeOverloaded :: (Monoid u) => SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d u) +newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniqueness) newTypeOverloaded loc name pts = - tyVarType <$> newTyVarWith loc name (TyVarPrim pts) + tyVarType NoUniqueness <$> newTyVarWith loc name (TyVarPrim pts) newSVar :: (Located loc) => loc -> Name -> TermM SVar newSVar _loc desc = do i <- incCounter newID $ mkTypeVarName desc i -asStructType :: (Monoid u) => SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) +asStructType :: SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do t1' <- asStructType loc t1 t2' <- asStructType loc t2 pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' -asStructType loc t = do - t' <- newType loc "artificial" - ctEq (toType t' `setUniqueness` NoUniqueness) (t `setUniqueness` NoUniqueness) +asStructType loc (Scalar (Record fs)) = + Scalar . Record <$> traverse (asStructType loc) fs +asStructType loc (Scalar (Sum cs)) = + Scalar . Sum <$> traverse (mapM (asStructType loc)) cs +asStructType loc t@(Scalar (TypeVar u _ _)) = do + t' <- newType loc "artificial" u + ctEq (toType t') t + pure t' +asStructType loc t@(Array u _ _) = do + t' <- newType loc "artificial" u + ctEq (toType t') t pure t' addCt :: Ct -> TermM () @@ -370,7 +378,7 @@ instTypeScheme _qn loc tparams t = do case tparam of TypeParamType _ v _ -> do v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType v')) + pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType NoUniqueness v')) TypeParamDim {} -> pure Nothing let t' = applySubst (`lookup` substs) t @@ -398,11 +406,11 @@ lookupVar loc qn@(QualName qs name) = do asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do argtype <- newTypeOverloaded loc "t" anyPrimType - pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool + pure $ foldFunType [toParam Observe argtype, toParam Observe argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType pts' $ RetType [] $ toRes Nonunique rt' + pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -426,9 +434,9 @@ bind idents = localScope (`bindVars` idents) -- literals in patterns. patLitMkType :: PatLit -> SrcLoc -> TermM ParamType patLitMkType (PatLitInt _) loc = - newTypeOverloaded loc "t" anyNumberType + toParam Observe <$> newTypeOverloaded loc "t" anyNumberType patLitMkType (PatLitFloat _) loc = - newTypeOverloaded loc "t" anyFloatType + toParam Observe <$> newTypeOverloaded loc "t" anyFloatType patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v @@ -449,12 +457,12 @@ checkPat' (PatAttr attr p loc) t = checkPat' (Id name NoInfo loc) (Ascribed t) = pure $ Id name (Info t) loc checkPat' (Id name NoInfo loc) NoneInferred = do - t <- newType loc "t" + t <- newType loc "t" Observe pure $ Id name (Info t) loc checkPat' (Wildcard _ loc) (Ascribed t) = pure $ Wildcard (Info t) loc checkPat' (Wildcard NoInfo loc) NoneInferred = do - t <- newType loc "t" + t <- newType loc "t" Observe pure $ Wildcard (Info t) loc checkPat' (TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, @@ -463,7 +471,7 @@ checkPat' (TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t :: [ParamType] <- replicateM (length ps) (newType loc "t") + ps_t <- replicateM (length ps) (newType loc "t" Observe) ctEq (toType (Scalar (tupleRecord ps_t))) (toType t) TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = @@ -473,7 +481,7 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) L.sort (map fst p_fs) == L.sort (M.keys t_fs) = RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do - p_fs' :: M.Map Name Type <- traverse (const $ newType loc "t") $ M.fromList p_fs + p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs ctEq (Scalar (Record p_fs')) $ toType t st <- asStructType loc $ Scalar (Record p_fs') checkPat' p $ Ascribed $ toParam Observe st @@ -526,15 +534,15 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do - p_t <- newType (srclocOf p) "t" + p_t <- newType (srclocOf p) "t" Observe checkPat' p $ Ascribed p_t - t' <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' + t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' ctEq t' (toType t) t'' <- asStructType loc t' pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps - t <- newTypeWithConstr loc "t" n $ map (toType . patternType) ps' + t <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' t' <- asStructType loc t pure $ PatConstr n (Info $ toParam Observe t') ps' loc @@ -649,10 +657,10 @@ checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) split ftype' = do - a <- newType loc "arg" - b <- newTyVar loc "res" - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b - pure (a, tyVarType b) + a <- newType loc "arg" NoUniqueness + b <- newType loc "res" Nonunique + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + pure (a, b `setUniqueness` NoUniqueness) -- To be removed (probably) checkApply_ :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) @@ -685,10 +693,10 @@ checkApply_ loc _ ftype fframe arg = do split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) split ftype' = do - a <- newType loc "arg" - b <- newTyVar loc "res" - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] $ tyVarType b - pure (a, tyVarType b) + a <- newType loc "arg" NoUniqueness + b <- newType loc "res" Nonunique + ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + pure (a, b `setUniqueness` NoUniqueness) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex @@ -714,8 +722,8 @@ mustHaveFields loc t [f] ve_t = do rt :: Type <- newTypeWithField loc "ft" f ve_t ctEq t rt mustHaveFields loc t (f : fs) ve_t = do - ft :: Type <- newType loc "ft" - rt :: Type <- newTypeWithField loc "rt" f ft + ft <- newType loc "ft" NoUniqueness + rt <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t ctEq t rt @@ -796,7 +804,7 @@ checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc checkExp (Hole NoInfo loc) = - Hole <$> (Info <$> newType loc "hole") <*> pure loc + Hole <$> (Info <$> newType loc "hole" NoUniqueness) <*> pure loc checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (TupLit es loc) = @@ -826,7 +834,7 @@ checkExp (ArrayLit es _ loc) = do -- type variables for pathologically large arrays with -- type-unsuffixed integers. Add some special case that handles that -- more efficiently. - et <- newType loc "et" + et <- newType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e ctEq (expType e') (toType et) @@ -868,7 +876,7 @@ checkExp (Assert e1 e2 NoInfo loc) = do -- checkExp (Constr name es NoInfo loc) = do es' <- mapM checkExp es - t <- newTypeWithConstr loc "t" name $ map expType es' + t <- newTypeWithConstr loc "t" NoUniqueness name $ map expType es' pure $ Constr name es' (Info t) loc -- checkExp (AppExp (Apply fe args loc) NoInfo) = do @@ -910,7 +918,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e - t2 <- newType loc "t" + t2 <- newType loc "t" NoUniqueness t2' <- asStructType loc t2 let t1 = typeOf e' f1 = frameOf e' @@ -928,7 +936,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e - t1 <- newType loc "t" + t1 <- newType loc "t" NoUniqueness t1' <- asStructType loc t1 let t2 = typeOf e' f2 = frameOf e' @@ -946,8 +954,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do loc -- checkExp (ProjectSection fields NoInfo loc) = do - a <- newType loc "a" - b <- newType loc "b" + a <- newType loc "a" NoUniqueness + b <- newType loc "b" NoUniqueness mustHaveFields loc (toType a) fields (toType b) let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc @@ -1002,14 +1010,14 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end - range_t <- newTyVar loc "range" - ctEq (tyVarType range_t :: Type) (arrayOfRank 1 (expType start')) - pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes (tyVarType range_t) [] + range_t <- newType loc "range" NoUniqueness + ctEq (toType range_t) (arrayOfRank 1 (expType start')) + pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e - kt <- newType loc "kt" - t :: Type <- newTypeWithField loc "t" k kt + kt <- newType loc "kt" NoUniqueness + t <- newTypeWithField loc "t" k kt ctEq (expType e') t kt' <- asStructType loc kt pure $ Project k e' (Info kt') loc @@ -1022,9 +1030,9 @@ checkExp (RecordUpdate src fields ve NoInfo loc) = do -- checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice - index_arg_t <- newType loc "index" - index_elem_t <- newType loc "index_elem" - index_res_t :: Type <- newType loc "index_res" + index_arg_t <- newType loc "index" NoUniqueness + index_elem_t <- newType loc "index_elem" NoUniqueness + index_res_t <- newType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (toType index_arg_t) $ arrayOfRank num_slices index_elem_t ctEq index_res_t $ arrayOfRank (length slice) index_elem_t @@ -1035,8 +1043,8 @@ checkExp (IndexSection slice NoInfo loc) = do checkExp (AppExp (Index e slice loc) _) = do e' <- checkExp e slice' <- checkSlice slice - index_t <- newType loc "index" - index_elem_t <- newType loc "index_elem" + index_t <- newType loc "index" NoUniqueness + index_elem_t <- newType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (toType index_t) $ arrayOfRank num_slices index_elem_t ctEq (expType e') $ arrayOfRank (length slice) index_elem_t @@ -1047,7 +1055,7 @@ checkExp (Update src slice ve loc) = do slice' <- checkSlice slice ve' <- checkExp ve let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" + update_elem_t <- newType loc "update_elem" NoUniqueness ctEq (expType src') $ arrayOfRank (length slice) update_elem_t ctEq (expType ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc @@ -1059,7 +1067,7 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do slice' <- checkSlice slice ve' <- checkExp ve let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" + update_elem_t <- newType loc "update_elem" NoUniqueness ctEq (toType src_t) $ arrayOfRank (length slice) update_elem_t ctEq (expType ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do @@ -1100,7 +1108,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr - elem_t <- newType elemp "elem" + elem_t <- newType elemp "elem" NoUniqueness ctEq (expType arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body From 3212b972d4dd569cce24cec3e732f226569a456d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 14:35:10 -0800 Subject: [PATCH 100/296] Hack to detect no integer solutions from `glpk`. --- futhark.cabal | 1 + src/Futhark/Solve/GLPK.hs | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index 8cf87483f3..8c00ce32bc 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -497,6 +497,7 @@ library -- remove me later , process , glpk-hs + , silently executable futhark import: common diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index fe7ac5d129..b2d340d683 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -1,12 +1,14 @@ module Futhark.Solve.GLPK (glpk) where +import Control.Monad import Data.LinearProgram import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Futhark.Solve.LP qualified as F +import System.IO.Silently -linearProgToGLPK :: (Show v, Ord v, Eq a, Num a, Group a) => F.LinearProg v a -> (LP v a) +linearProgToGLPK :: (Ord v, Eq a, Num a) => F.LinearProg v a -> (LP v a) linearProgToGLPK prog = LP { direction = cOptType $ F.optType prog, @@ -38,12 +40,16 @@ linearProgToGLPK prog = varList = S.toList $ F.vars prog -glpk :: - (Show v, Ord v, Show a, Eq a, Real a, Group a) => - F.LinearProg v a -> - IO (Maybe (Int, M.Map v Int)) +glpk :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) glpk lp = do + (output, res) <- capture $ glpk' lp + pure $ do + guard $ "PROBLEM HAS NO INTEGER FEASIBLE SOLUTION" `notElem` lines output + res + +glpk' :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) +glpk' lp = do (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres where - opts = mipDefaults {msgLev = MsgOff} + opts = mipDefaults {msgLev = MsgAll} From eb49178500a27cf98e5c9cab4707d6a7ec67eaa9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 16:45:17 -0800 Subject: [PATCH 101/296] Add some notes to clarify how this will actually work. --- src/Futhark/Internalise/Exps.hs | 92 +++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index c17db733ba..09285aacb2 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -888,6 +888,98 @@ withAutoMap_ ams arg_desc res_t args_e innerM = mapM_ addStms $ reverse stms innerM args +-- | Internalization of 'AutoMap'-annotated applications. +-- +-- Each application @f x@ has an annotation with @AutoMap R M F@ where +-- @R, M, F@ are the autorep, automap, and frame shapes, +-- respectively. +-- +-- The application @f x@ will have type @F t@ for some @t@, i.e. @(f +-- x) : F t@. The frame @F@ is a prefix of the type of @f x@; namely +-- it is the total accumulated shape that is due to implicit maps. +-- Another way of thinking about that is that @|F|@ is is the level +-- of the automap-nest that @f x@ is in. For example, if @|F| = 2@ +-- then we know that @f x@ implicitly stands for +-- +-- > map (\x' -> map (\x'' -> f x'') x') x +-- +-- For an application with a non-empty autorep annotation, the frame +-- tells about how many dimensions of the replicate can be eliminated. +-- For example, @[[1,2],[3,4]] + 5@ will yield the following annotations: +-- +-- > ([[1,2],[3,4]] +) -- AutoMap {R = mempty, M = [2][2], F = [2][2]} +-- > (([[1,2],[3,4]] +) 5) -- AutoMap {R = [2][2], M = mempty, F = [2][2]} +-- +-- All replicated arguments are pushed down the auto-map nest. Each +-- time a replicated argument is pushed down a level of an +-- automap-nest, one fewer replicates is needed (i.e., the outermost +-- dimension of @R@ can be dropped). Replicated arguments are pushed +-- down the nest until either 1) the bottom of the nest is encountered +-- or 2) no replicate dimensions remain. For example, in the second +-- application above @R@ = @F@, so we can push the replicated argument +-- down two levels. Since each level effectively removes a dimension +-- of the replicate, no replicates will be required: +-- +-- > map (\xs -> map (\x -> f x'' 5) xs) [[1,2],[3,4]] +-- +-- The number of replicates that are actually required is given by +-- max(|R| - |F|, 0). +-- +-- An expression's "true level" is the level at which that expression +-- will appear in the automap-nest. The bottom of a mapnest is level 0. +-- +-- * For annotations with @R = mempty@, the true level is @|F|@. +-- * For annotations with @M = mempty@, the true level is @|F| - |R|@. +-- +-- If @|R| > |F|@ then actual replicates (namely @|R| - |F|@ of them) +-- will be required at the bottom of the mapnest. +-- +-- Note that replicates can only appear at the bottom of a mapnest; any +-- expression of the form +-- +-- > map (\ls x' rs -> e) (replicate x) +-- +-- can always be written as +-- +-- > map (\ls rs -> e[x' -> x]) +-- +-- Let's look at another example. Consider (with exact sizes omitted for brevity) +-- +-- > f : a -> a -> a -> []a -> [][][]a -> a +-- > xss : [][]a +-- > ys : []a +-- > zsss : [][][]a +-- > w : a +-- > vss : [][]a +-- +-- and the application +-- +-- > f xss ys zsss w vss +-- +-- which will have the following annotations +-- +-- > (f xss) -- AutoMap {R = mempty, M = [][], F = [][]} (1) +-- > ((f xss) ys) -- AutoMap {R = [], M = mempty, F = [][]} (2) +-- > (((f xss) ys) zsss) -- AutoMap {R = mempty, M = [], F = [][][]} (3) +-- > ((((f xss) ys) zsss) w) -- AutoMap {R = [][][][], M = mempty, F = [][][]} (4) +-- > (((((f xss) ys) zsss) w) vss) -- AutoMap {R = [], M = mempty, F = [][][]} (5) +-- +-- This will yield the following mapnest. +-- +-- > map (\zss -> +-- > map (\xs zs vs -> +-- > map (\x y z v -> f x y z (replicate w) v) xs ys zs v) xss zss vss) zsss +-- +-- Let's see how we'd construct this mapnest from the annotations. We construct +-- the nest bottom-up. We have: +-- +-- Application | True level +-- --------------------------- +-- (1) | |[][]| = 2 +-- (2) | |[][]| - |[]| = 1 +-- (3) | |[][][]| = 3 +-- (4) | |[][][]| - |[][][][]| = -1 +-- (5) | |[][][]| - |[]| = 2 withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap ams arg_desc res_t args_e innerM = do (args, stms) <- From 94af0909b06bbe9bb324154aa59ee8f5e06dc90d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 22 Feb 2024 20:28:09 -0800 Subject: [PATCH 102/296] Bit more. --- src/Futhark/Internalise/Exps.hs | 38 +++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 09285aacb2..21620aa993 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -975,11 +975,41 @@ withAutoMap_ ams arg_desc res_t args_e innerM = -- -- Application | True level -- --------------------------- --- (1) | |[][]| = 2 --- (2) | |[][]| - |[]| = 1 --- (3) | |[][][]| = 3 +-- (1) | |[][]| = 2 +-- (2) | |[][]| - |[]| = 1 +-- (3) | |[][][]| = 3 -- (4) | |[][][]| - |[][][][]| = -1 --- (5) | |[][][]| - |[]| = 2 +-- (5) | |[][][]| - |[]| = 2 +-- +-- We start at level 0. +-- * Any argument with a negative true level of @-n@ will be replicated @n@ times; +-- the exact shapes can be found by removing the @F@ postfix from @R@, +-- i.e. @R = shapes_to_rep_by <> F@. +-- * Any argument with a 0 true level will be included. +-- * For any argument @arg@ with a positive true level, we construct a new parameter +-- whose type is @arg@ with the leading @n@ dimensions (where @n@ is the true level) +-- removed. +-- +-- Following the rules above, @w@ will be replicated once. For the remaining arguments, +-- we create new parameters @x : a, y : a, z : a , v : a@. Hence, level 0 becomes +-- +-- > f x y z (replicate w) v +-- +-- At level l > 0: +-- * There are no replicates. +-- * Any argument with l true level will be included verbatim. +-- * Any argument with true level > l will have a new parameter constructed for it, +-- whose type has the leading @n - l@ dimensions (where @n@ is the true level) removed. +-- * We surround the previous level with a map that binds that levels' new parameters +-- and is passed the current levels' arguments. +-- +-- Following the above recipe for level 1, we create parameters +-- @xs : []a, zs : []a, vs :[]a@ and obtain +-- +-- > map (\x y z v -> f x y z (replicate w) v) xs ys zs vs +-- +-- This process continues until the level is greater than the maximum +-- true level of any application, at which we terminate. withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap ams arg_desc res_t args_e innerM = do (args, stms) <- From a7c8dd9fff8f00e63125b6c40d9de60f8064e80e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 11:34:44 +0100 Subject: [PATCH 103/296] Unnecessary warnings. --- src/Language/Futhark/TypeChecker/Rank.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 1eef479dcf..6c2504ec3e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -186,7 +186,7 @@ solveRankILP loc prog = do [] -> typeError loc mempty "Rank ILP cannot be solved." rs -> do traceM "## rank maps" - forM (zip [0 :: Int ..] rs) $ \(i, r) -> + forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> traceM $ unlines $ "\n## rank map " <> prettyString i @@ -198,7 +198,7 @@ rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] rankAnalysis loc cs tyVars body = do rank_maps <- solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps - let bodys = map (flip updAM body) rank_maps + let bodys = map (`updAM` body) rank_maps pure $ zip cts_tyvars' bodys where splitFuncs @@ -335,7 +335,7 @@ updAM rank_map e = args' = fmap ( bimap - (fmap $ bimap id upd) + (fmap $ second upd) (updAM rank_map) ) args From 86acaed8db1f914d6318a5c79b4f37326248fd0b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 11:38:46 +0100 Subject: [PATCH 104/296] Slices better be i64. --- src/Language/Futhark/TypeChecker/Terms2.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d9e5930b71..024975b100 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -358,6 +358,9 @@ arrayOfRank :: Int -> Type -> Type arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp +require _why [pt] e = do + ctEq (Scalar $ Prim pt) (expType e) + pure e require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts ctEq t $ expType e @@ -702,11 +705,11 @@ checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where checkDimIndex (DimFix i) = - DimFix <$> check i + DimFix <$> (require "use as index" anySignedType =<< checkExp i) checkDimIndex (DimSlice i j s) = DimSlice <$> traverse check i <*> traverse check j <*> traverse check s - check = require "use as index" anySignedType <=< checkExp + check = require "use in slice" [Signed Int64] <=< checkExp isSlice :: DimIndexBase f vn -> Bool isSlice DimSlice {} = True From 92be29dca09a2dd295e084e9b0b9ca369d5c5a6e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 14:28:12 +0100 Subject: [PATCH 105/296] Implement occurs check. --- .../Futhark/TypeChecker/Constraints.hs | 43 ++++++++++++------- src/Language/Futhark/TypeChecker/Rank.hs | 1 + tests/{issue1599.fut => types/occurs.fut} | 1 + 3 files changed, 30 insertions(+), 15 deletions(-) rename tests/{issue1599.fut => types/occurs.fut} (53%) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ddaedd4025..ddb8e26e2c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -16,12 +16,11 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor -import Data.List qualified as L import Data.Map qualified as M import Data.Maybe +import Data.Set qualified as S import Data.Text qualified as T import Debug.Trace -import Futhark.IR.Pretty import Futhark.Util.Pretty import Language.Futhark @@ -38,7 +37,7 @@ data SComp instance Pretty SComp where pretty SDim = "[]" - pretty (SVar x) = brackets $ pretty x + pretty (SVar x) = brackets $ prettyName x instance Pretty (Shape SComp) where pretty = mconcat . map pretty . shapeDims @@ -59,7 +58,7 @@ data Ct instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtAM r m) = pretty r <+> "=" <+> "•" <+> "∨" <+> pretty m <+> "=" <+> "•" + pretty (CtAM r m) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" type Constraints = [Ct] @@ -142,12 +141,25 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) +occursCheck :: VName -> Type -> SolveM () +occursCheck v tp = do + vars <- gets solverTyVars + let tp' = substTyVars vars tp + when (v `S.member` typeVars tp') . throwError . docText $ + "Occurs check: cannot instantiate" + <+> prettyName v + <+> "with" + <+> pretty tp + <> "." + subTyVar :: VName -> Int -> Type -> SolveM () -subTyVar v lvl t = +subTyVar v lvl t = do + occursCheck v t modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} linkTyVar :: VName -> VName -> SolveM () -linkTyVar v t = +linkTyVar v t = do + occursCheck v $ Scalar $ TypeVar NoUniqueness (qualName t) [] modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} -- Unify at the root, emitting new equalities that must hold. @@ -164,11 +176,11 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) Just $ M.elems $ M.intersectionWith (,) fs1 fs2 unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = - fmap concat - . forM (M.elems $ M.intersectionWith (,) cs1 cs2) - $ \(ts1, ts2) -> do - guard $ length ts1 == length ts2 - Just $ zip ts1 ts2 + fmap concat . forM cs' $ \(ts1, ts2) -> do + guard $ length ts1 == length ts2 + Just $ zip ts1 ts2 + where + cs' = M.elems $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = @@ -224,9 +236,10 @@ solveCt ct = solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = - second solution - . runExcept - . flip execStateT (initialState tyvars) - . runSolveM + trace (unlines (map prettyString constraints)) + $ second solution + . runExcept + . flip execStateT (initialState tyvars) + . runSolveM $ mapM solveCt constraints {-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 6c2504ec3e..fd9904998b 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -9,6 +9,7 @@ import Data.Map (Map) import Data.Map qualified as M import Data.Maybe import Debug.Trace +import Futhark.IR.Pretty () import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP diff --git a/tests/issue1599.fut b/tests/types/occurs.fut similarity index 53% rename from tests/issue1599.fut rename to tests/types/occurs.fut index 3ce47c38b1..c37b1448c4 100644 --- a/tests/issue1599.fut +++ b/tests/types/occurs.fut @@ -1,3 +1,4 @@ +-- Simple instance of an occurs check. -- == -- error: Occurs From aec83f333274151c9ee15b10b5948eeda802a1d3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 14:28:34 +0100 Subject: [PATCH 106/296] This is too much. --- src/Language/Futhark/TypeChecker/Constraints.hs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ddb8e26e2c..9565804534 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -236,10 +236,9 @@ solveCt ct = solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = - trace (unlines (map prettyString constraints)) - $ second solution - . runExcept - . flip execStateT (initialState tyvars) - . runSolveM + second solution + . runExcept + . flip execStateT (initialState tyvars) + . runSolveM $ mapM solveCt constraints {-# NOINLINE solve #-} From df35808a604bbb524a3e2f7b8fea0aef826f6c7c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 14:39:47 +0100 Subject: [PATCH 107/296] Add Futhark.Util.debugTraceM. --- src/Futhark/Internalise/Defunctionalise.hs | 11 +++++------ src/Futhark/Internalise/Exps.hs | 1 - src/Futhark/Util.hs | 8 ++++++++ src/Language/Futhark/TypeChecker.hs | 4 ++-- src/Language/Futhark/TypeChecker/Constraints.hs | 1 - src/Language/Futhark/TypeChecker/Rank.hs | 8 ++++---- src/Language/Futhark/TypeChecker/Terms.hs | 1 - src/Language/Futhark/TypeChecker/Terms2.hs | 9 ++++----- 8 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 249616ef67..98b698aeeb 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -14,10 +14,9 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Debug.Trace import Futhark.IR.Pretty () import Futhark.MonadFreshNames -import Futhark.Util (mapAccumLM, nubOrd) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types (Subst (..), applySubst) @@ -956,7 +955,7 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) callret <- unRetType lifted_rettype - traceM $ + debugTraceM $ unlines [ "##defuncApplyArg LambdaSV", "## fname", @@ -985,7 +984,7 @@ defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] apply_e = mkApply f' [(argext, am, arg')] callret - traceM $ + debugTraceM $ unlines [ "##defuncApplyArg DynamicFun", "## f'", @@ -1046,7 +1045,7 @@ defuncApply f args appres loc = do -- ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) am_dims ams = NE.toList $ autoMap . snd . fst <$> args ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) ams - traceM $ + debugTraceM $ unlines [ "## defuncApply", "## f", @@ -1062,7 +1061,7 @@ defuncApply f args appres loc = do "## f type", prettyString $ typeOf f, "## arg types", - prettyString $ (typeOf . snd) <$> args, + prettyString $ typeOf . snd <$> args, "## ret_am", prettyString ret_am ] diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 21620aa993..5f933e659c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -16,7 +16,6 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index 76fce3b4af..6d41b1c7f2 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -51,6 +51,7 @@ module Futhark.Util fixPoint, concatMapM, topologicalSort, + debugTraceM, ) where @@ -77,6 +78,7 @@ import Data.Text.Encoding qualified as T import Data.Text.Encoding.Error qualified as T import Data.Time.Clock (UTCTime, getCurrentTime) import Data.Tuple (swap) +import Debug.Trace import Numeric import System.Directory.Tree qualified as Dir import System.Environment @@ -507,3 +509,9 @@ topologicalSort dep nodes = modify $ second $ IM.insert i True mapM_ sorting $ mapMaybe (depends_of node) nodes_idx modify $ bimap (node :) (IM.insert i False) + +-- | 'traceM', but only if @FUTHARK_COMPILER_DEBUGGING@ is set to at least 1. +debugTraceM :: (Monad m) => String -> m () +debugTraceM + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 = traceM + | otherwise = const $ pure () diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 7ce81ecca5..2f3371d02b 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -26,8 +26,8 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Ord import Data.Set qualified as S -import Debug.Trace import Futhark.FreshNames hiding (newName) +import Futhark.Util (debugTraceM) import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Semantic @@ -715,7 +715,7 @@ checkValBind vb = do let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - traceM $ unlines ["# Inferred:", prettyString vb'] + debugTraceM $ unlines ["# Inferred:", prettyString vb'] pure ( mempty diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 9565804534..cf5a4f2dbc 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -20,7 +20,6 @@ import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.Util.Pretty import Language.Futhark diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index fd9904998b..83815ec70e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -8,11 +8,11 @@ import Data.List qualified as L import Data.Map (Map) import Data.Map qualified as M import Data.Maybe -import Debug.Trace import Futhark.IR.Pretty () import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP +import Futhark.Util (debugTraceM) import Language.Futhark hiding (ScalarType) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints @@ -178,7 +178,7 @@ enumerateRankSols prog = solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] solveRankILP loc prog = do - traceM $ + debugTraceM $ unlines [ "## solveRankILP", prettyString prog @@ -186,9 +186,9 @@ solveRankILP loc prog = do case enumerateRankSols prog of [] -> typeError loc mempty "Rank ILP cannot be solved." rs -> do - traceM "## rank maps" + debugTraceM "## rank maps" forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> - traceM $ + debugTraceM $ unlines $ "\n## rank map " <> prettyString i : map prettyString (M.toList r) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 8c48591696..76578151cb 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -27,7 +27,6 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Debug.Trace import Futhark.Util (mapAccumLM, nubOrd, topologicalSort) import Futhark.Util.Pretty hiding (space) import Language.Futhark diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 024975b100..43ef74251c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -55,10 +55,9 @@ import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.FreshNames qualified as FreshNames import Futhark.MonadFreshNames hiding (newName) -import Futhark.Util (mapAccumLM) +import Futhark.Util (debugTraceM, mapAccumLM) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Constraints @@ -1180,9 +1179,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - traceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - traceM $ + debugTraceM $ unlines [ "## cts:", unlines $ map prettyString cts @@ -1196,7 +1195,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do . uncurry solve forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> - traceM $ + debugTraceM $ unlines [ "## constraints:", unlines $ map prettyString cts', From 02c71b7c9f500ffac7e5fc2cbaa2f9a6e5499f03 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 06:04:13 -0800 Subject: [PATCH 108/296] Prettyprint AUTOMAP annotations. --- src/Language/Futhark/Pretty.hs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 81ca4a152f..886f69b066 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -311,11 +311,21 @@ prettyAppExp _ (If c t f _) = prettyAppExp p (Apply f args _) = parensIf (p >= 10) $ prettyExp 0 f - <+> hsep (map (prettyExp 10 . snd) $ NE.toList args) + <+> hsep (map prettyArg $ NE.toList args) + where + prettyArg (i, e) = + case unAnnot i of + Just (_, am) + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + parens (prettyExp 10 e <+> "Δ" <+> pretty am) + _ -> prettyExp 10 e instance (Eq vn, IsName vn, Annot f) => Pretty (AppExpBase f vn) where pretty = prettyAppExp (-1) +instance Pretty AutoMap where + pretty (AutoMap r m f) = encloseSep lparen rparen comma $ map pretty [r, m, f] + prettyInst :: (Annot f, Pretty t) => f t -> Doc a prettyInst t = case unAnnot t of From 2a64af32824306c56e3027837edbb50024c8c6db Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 06:15:16 -0800 Subject: [PATCH 109/296] Also prettyprint binops. --- src/Language/Futhark/Pretty.hs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 886f69b066..8519b22060 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -234,7 +234,13 @@ letBody body@(AppExp LetFun {} _) = pretty body letBody body = "in" <+> align (pretty body) prettyAppExp :: (Eq vn, IsName vn, Annot f) => Int -> AppExpBase f vn -> Doc a -prettyAppExp p (BinOp (bop, _) _ (x, _) (y, _) _) = prettyBinOp p bop x y +prettyAppExp p (BinOp (bop, _) _ (x, xi) (y, yi) _) = + case (unAnnot xi, unAnnot yi) of + (Just (_, xam), Just (_, yam)) + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + -- fix + parens (prettyBinOp p bop x y <+> "Δ" <+> pretty xam <+> "Δ" <+> pretty yam) + _ -> prettyBinOp p bop x y prettyAppExp _ (Match e cs _) = "match" <+> pretty e (stack . map pretty) (NE.toList cs) prettyAppExp _ (Loop sizeparams pat initexp form loopbody _) = "loop" From d201507b14c3245500d0ec0a6b1a357e0e338e03 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 23 Feb 2024 16:00:18 +0100 Subject: [PATCH 110/296] Fix checkOneExp. --- src/Language/Futhark/TypeChecker/Terms.hs | 24 +++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 76578151cb..3a1e9fc663 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1007,16 +1007,20 @@ checkApply _ _ _ _ _ = -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) -checkOneExp e = runTermTypeM checkExp mempty $ do - e' <- checkExp $ undefined e - let t = typeOf e' - (tparams, _, _) <- - letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t - fixOverloadedTypes $ typeVars t - e'' <- updateTypes e' - localChecks e'' - causalityCheck e'' - pure (tparams, e'') +checkOneExp e = do + (maybe_tysubsts, e') <- Terms2.checkSingleExp e + case maybe_tysubsts of + Left err -> typeError e' mempty $ pretty err + Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + e'' <- checkExp e' + let t = typeOf e'' + (tparams, _, _) <- + letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t + fixOverloadedTypes $ typeVars t + e''' <- updateTypes e'' + localChecks e''' + causalityCheck e''' + pure (tparams, e''') -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. From 2119f2971270ac34511ec505006544caac4e2e3d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 07:01:42 -0800 Subject: [PATCH 111/296] Fix frame duplication. --- src/Language/Futhark/TypeChecker/Terms2.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 43ef74251c..40e3ae83a1 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -637,8 +637,9 @@ checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] m) unit_info mempty - lhs = arrayOf (toShape (SVar r) <> (toSComp <$> argframe)) argtype + lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a + ctAM r m ctEq lhs rhs pure @@ -673,7 +674,7 @@ checkApply_ loc _ ftype fframe arg = do let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] m) unit_info mempty - lhs = arrayOf (toShape (SVar r) <> (toSComp <$> frameOf arg)) $ toType $ typeOf arg + lhs = arrayOf (toShape (SVar r)) $ toType $ typeOf arg rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a ctAM r m ctEq lhs rhs From 81d53e5a0b7d768742acab432db6ccfc31a59d46 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 07:46:47 -0800 Subject: [PATCH 112/296] Remove `checkApply_`. --- src/Language/Futhark/TypeChecker/Terms2.hs | 96 +++++++++------------- 1 file changed, 40 insertions(+), 56 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 40e3ae83a1..3f8b0df53b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -614,14 +614,12 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> [(Shape Size, Type)] -> TermM (StructType, [AutoMap]) +checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> NE.NonEmpty (Shape Size, Type) -> TermM (StructType, NE.NonEmpty AutoMap) checkApply loc fname (fframe, ftype) args = do ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args rt' <- asStructType loc rt pure (rt', argts) where - -- pure (asStructType loc rt, argts) - onArg (i, f_f, f_t) (argframe, argtype) = do (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) pure @@ -665,42 +663,6 @@ checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) --- To be removed (probably) -checkApply_ :: SrcLoc -> (Maybe (QualName VName), Int) -> Type -> Shape Size -> Exp -> TermM (Type, AutoMap) -checkApply_ loc _ ftype fframe arg = do - (a, b) <- split $ stripFrame fframe ftype - r <- newSVar loc "R" - m <- newSVar loc "M" - let unit_info = Info $ Scalar $ Prim Bool - r_var = Var (QualName [] r) unit_info mempty - m_var = Var (QualName [] m) unit_info mempty - lhs = arrayOf (toShape (SVar r)) $ toType $ typeOf arg - rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a - ctAM r m - ctEq lhs rhs - pure - ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, - AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} - ) - where - stripFrame :: Shape Size -> Type -> Type - stripFrame frame (Array u ds t) = - let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) - in case mnew_shape of - Nothing -> Scalar t - Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t - stripFrame _ t = t - toSComp (Var (QualName [] x) _ _) = SVar x - toSComp _ = error "" - toShape = Shape . pure - split (Scalar (Arrow _ _ _ a (RetType _ b))) = - pure (a, b `setUniqueness` NoUniqueness) - split ftype' = do - a <- newType loc "arg" NoUniqueness - b <- newType loc "res" Nonunique - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b - pure (a, b `setUniqueness` NoUniqueness) - checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where @@ -884,22 +846,41 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args - rt' <- asStructType loc rt - pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] + (args', argts') <- + NE.unzip + <$> forM + args + ( \(_, arg) -> do + arg' <- checkExp arg + pure (arg', (frameOf arg', expType arg')) + ) + (rt, ams) <- checkApply loc fname (frameOf fe', expType fe') argts' + pure $ + AppExp (Apply fe' (NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args') loc) $ + Info $ + AppRes rt [] where fname = case fe of Var v _ _ -> Just v _ -> Nothing - - onArg (i, f_t, f_f) (_, arg) = do - arg' <- checkExp arg - (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' - pure - ( (i + 1, rt, autoFrame am), - (Info (Nothing, am), arg') - ) +-- fe' <- checkExp fe +-- ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args +-- rt' <- asStructType loc rt +-- pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] +-- where +-- fname = +-- case fe of +-- Var v _ _ -> Just v +-- _ -> Nothing + +-- onArg (i, f_t, f_f) (_, arg) = do +-- arg' <- checkExp arg +-- (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' +-- pure +-- ( (i + 1, rt, autoFrame am), +-- (Info (Nothing, am), arg') +-- ) -- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op @@ -910,8 +891,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do loc (Just op) (mempty, toType ftype) - [(frameOf e1', toType $ typeOf e1'), (frameOf e2', toType $ typeOf e2')] - let [am1, am2] = ams + ((frameOf e1', toType $ typeOf e1') NE.:| [(frameOf e2', toType $ typeOf e2')]) + let (am1 NE.:| [am2]) = ams pure $ AppExp @@ -925,13 +906,15 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do t2' <- asStructType loc t2 let t1 = typeOf e' f1 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(f1, toType t1), (mempty, t2)] + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((f1, toType t1) NE.:| [(mempty, t2)]) + + let (am1 NE.:| _) = ams pure $ OpSectionLeft op (Info optype) e' - ( Info (Unnamed, toParam Observe t1, Nothing, head ams), -- fix + ( Info (Unnamed, toParam Observe t1, Nothing, am1), Info (Unnamed, toParam Observe t2') ) (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) @@ -943,7 +926,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do t1' <- asStructType loc t1 let t2 = typeOf e' f2 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) [(mempty, t1), (f2, toType t2)] + (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((mempty, t1) NE.:| [(f2, toType t2)]) + let (_ NE.:| [am2]) = ams pure $ OpSectionRight op @@ -951,7 +935,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do e' -- Dummy types. ( Info (Unnamed, toParam Observe t1'), - Info (Unnamed, toParam Observe t2, Nothing, ams !! 1) -- fix + Info (Unnamed, toParam Observe t2, Nothing, am2) ) (Info $ RetType [] (rt `setUniqueness` Nonunique)) loc From 6b171951cdabb1f2026f6b8a313b9f2f9ab59a9e Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 07:48:20 -0800 Subject: [PATCH 113/296] Forgot to remove this too. --- src/Language/Futhark/TypeChecker/Terms2.hs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 3f8b0df53b..b3a3106435 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -864,24 +864,6 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do case fe of Var v _ _ -> Just v _ -> Nothing --- fe' <- checkExp fe --- ((_, rt, _), args') <- mapAccumLM onArg (0, expType fe', frameOf fe') args --- rt' <- asStructType loc rt --- pure $ AppExp (Apply fe' args' loc) $ Info $ AppRes rt' [] --- where --- fname = --- case fe of --- Var v _ _ -> Just v --- _ -> Nothing - --- onArg (i, f_t, f_f) (_, arg) = do --- arg' <- checkExp arg --- (rt, am) <- checkApply_ loc (fname, i) f_t f_f arg' --- pure --- ( (i + 1, rt, autoFrame am), --- (Info (Nothing, am), arg') --- ) --- checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op e1' <- checkExp e1 From e09ad671ddaa2e0b7e735c0f1f614c3f2ae45bed Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 17:41:59 -0800 Subject: [PATCH 114/296] Rank fixes. --- .../Futhark/TypeChecker/Constraints.hs | 2 +- src/Language/Futhark/TypeChecker/Rank.hs | 39 +++++++++++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index cf5a4f2dbc..63e2320a47 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -71,7 +71,7 @@ data TyVarInfo TyVarRecord (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum (M.Map Name [Type]) - deriving (Show) + deriving (Show, Eq) instance Pretty TyVarInfo where pretty TyVarFree = "free" diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 83815ec70e..7fffe64107 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -55,12 +55,16 @@ instance Rank Type where class Distribute a where distribute :: a -> a -instance Distribute Type where - distribute = distributeOne - where - distributeOne (Array _ s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ Arrow NoUniqueness Unnamed mempty (arrayOf s ta) (RetType rd $ arrayOfWithAliases Nonunique s tr) - distributeOne t = t +instance Distribute (TypeBase dim u) where + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute (arrayOfWithAliases Nonunique s tr)) + distribute t = t instance Distribute Ct where distribute (CtEq t1 t2) = distribute t1 `CtEq` distribute t2 @@ -169,6 +173,7 @@ enumerateRankSols prog = run_glpk = unsafePerformIO . glpk next_sol m = do (prog', sol') <- m + guard (fst sol' /= 0) let prog'' = ambigCheckLinearProg prog' sol' sol'' <- run_glpk prog'' pure (prog'', sol'') @@ -197,11 +202,12 @@ solveRankILP loc prog = do rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] rankAnalysis loc cs tyVars body = do - rank_maps <- solveRankILP loc (mkLinearProg (foldMap splitFuncs cs) tyVars) + rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps let bodys = map (`updAM` body) rank_maps pure $ zip cts_tyvars' bodys where + cs' = foldMap (splitFuncs . distribute) cs splitFuncs ( CtEq (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) @@ -288,9 +294,20 @@ rankToShape x = do addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do rs <- asks envRanks - unless (fromMaybe 0 (rs M.!? t) == 0) $ do - new_vars <- gets substNewVars - maybe new_var (const $ pure ()) $ new_vars M.!? t + if (fromMaybe 0 (rs M.!? t) == 0) + then do + old_tyvars <- asks envTyVars + case old_tyvars M.!? t of + -- Probably not needed + -- Just (lvl, TyVarFree) -> + -- -- is anyPrimType right here? + -- modify $ + -- \s -> s {substTyVars = M.insert t (lvl, TyVarPrim anyPrimType) $ substTyVars s} + _ -> do + pure () + else do + new_vars <- gets substNewVars + maybe new_var (const $ pure ()) $ new_vars M.!? t where new_var = do t' <- newTyVar t @@ -336,7 +353,7 @@ updAM rank_map e = args' = fmap ( bimap - (fmap $ second upd) + (fmap $ bimap id upd) (updAM rank_map) ) args From e2505f5acaca043ca4a4004ab20300ca32a350ce Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 17:43:54 -0800 Subject: [PATCH 115/296] Don't peel frames and distribute instead. --- src/Language/Futhark/TypeChecker/Terms.hs | 23 ++++++-- src/Language/Futhark/TypeChecker/Terms2.hs | 68 +++++++++++++++++----- 2 files changed, 73 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 3a1e9fc663..2107f56fd0 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -27,7 +27,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Futhark.Util (mapAccumLM, nubOrd, topologicalSort) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd, topologicalSort) import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Primitive (intByteSize) @@ -941,13 +941,28 @@ checkApply :: Exp -> AutoMap -> TermTypeM (StructType, StructType, Maybe VName, [VName], AutoMap) -checkApply loc (fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do +checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 let (am_frame_shape, argtype_automap) = splitArrayAt (autoFrameRank am) argtype_with_frame + debugTraceM $ + unlines + [ "## checkApply", + "## fn", + prettyString fn, + "## ft", + prettyString ft, + "## tp1_with_frame", + prettyString tp1_with_frame, + "## argtype_with_frame", + prettyString argtype_with_frame, + "## am", + show am + ] + unify (mkUsage argexp "use as function argument") tp1_with_frame argtype_with_frame -- Perform substitutions of instantiated variables in the types. @@ -990,8 +1005,8 @@ checkApply loc (fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let am = AutoMap - { autoMap = am_map_shape, - autoRep = mempty, + { autoRep = mempty, + autoMap = am_map_shape, autoFrame = am_map_shape <> am_frame_shape } diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b3a3106435..59182007b5 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -628,41 +628,77 @@ checkApply loc fname (fframe, ftype) args = do ) checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) -checkApplyOne loc _ (fframe, ftype) (argframe, argtype) = do - (a, b) <- split $ stripFrame fframe ftype +checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do + (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" let unit_info = Info $ Scalar $ Prim Bool r_var = Var (QualName [] r) unit_info mempty m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r)) argtype - rhs = arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) a - + rhs = arrayOf (toShape (SVar m)) a ctAM r m ctEq lhs rhs + debugTraceM $ + unlines $ + [ "## checkApplyOne", + "## fname", + prettyString fname, + "## (fframe, ftype)", + prettyString (fframe, ftype), + "## (argframe, argtype)", + prettyString (argframe, argtype), + "## r", + prettyString r, + "## m", + prettyString m, + "## lhs", + prettyString lhs, + "## rhs", + prettyString rhs, + "## ret", + prettyString $ arrayOf (toShape (SVar m)) b + ] pure - ( arrayOf (toShape (SVar m) <> (toSComp <$> fframe)) b, + ( arrayOf (toShape (SVar m)) b, AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where - stripFrame :: Shape Size -> Type -> Type - stripFrame frame (Array u ds t) = - let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) - in case mnew_shape of - Nothing -> Scalar t - Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t - stripFrame _ t = t + -- stripFrame :: Shape Size -> Type -> Type + -- stripFrame frame (Array u ds t) = + -- let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) + -- in case mnew_shape of + -- Nothing -> Scalar t + -- Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t + -- stripFrame _ t = t + + isFunType (Scalar Arrow {}) = True + isFunType _ = False -- (fix) toSComp (Var (QualName [] x) _ _) = SVar x toSComp _ = error "" toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) + split (Array u s t) = do + (a, b) <- split $ Scalar t + pure (arrayOf s a, arrayOf s b) split ftype' = do a <- newType loc "arg" NoUniqueness b <- newType loc "res" Nonunique ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) +distribute :: TypeBase dim u -> TypeBase dim u +distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute (arrayOfWithAliases (uniqueness tr) s tr)) +distribute t = t + checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where @@ -1151,7 +1187,11 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do debugTraceM $ unlines [ "## cts:", - unlines $ map prettyString cts + unlines $ map prettyString cts, + "## body:", + prettyString body', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' @@ -1166,7 +1206,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines [ "## constraints:", unlines $ map prettyString cts', - "## tyvars:", + "## tyvars':", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (v, t) = prettyNameString v <> " => " <> prettyString t From f25bd9e57cec8fcf49ba48d7e27c83223df2029c Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 18:56:41 -0800 Subject: [PATCH 116/296] Change objective from `M_i + R_i` to `M_i + max(0, |R_i| - |F_i|).` --- src/Futhark/Solve/LP.hs | 22 ++++++++++++++++- .../Futhark/TypeChecker/Constraints.hs | 6 ++--- src/Language/Futhark/TypeChecker/Rank.hs | 24 ++++++++++++++----- src/Language/Futhark/TypeChecker/Terms.hs | 8 ++++++- src/Language/Futhark/TypeChecker/Terms2.hs | 6 ++--- 5 files changed, 52 insertions(+), 14 deletions(-) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 044f6efe63..a2224617ea 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -8,6 +8,8 @@ module Futhark.Solve.LP cval, bin, or, + min, + max, oneIsZero, (~+~), (~-~), @@ -45,7 +47,7 @@ import Futhark.Solve.Matrix (Matrix (..)) import Futhark.Solve.Matrix qualified as M import Futhark.Util.Pretty import Language.Futhark.Pretty -import Prelude hiding (or) +import Prelude hiding (max, min, or) import Prelude qualified -- | A linear program. 'LP c a d' represents the program @@ -203,6 +205,24 @@ linearProgToPulp prog = bigM :: (Num a) => a bigM = 2 ^ 10 +-- max{x, y} = z +max :: (Eq a, Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] +max b x y z = + [ z ~>=~ x, + z ~>=~ y, + z ~<=~ x ~+~ bigM ~*~ var b, + z ~<=~ y ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +-- min{x, y} = z +min :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +min b x y z = + [ var z ~<=~ var x, + var z ~<=~ var y, + var z ~>=~ var x ~-~ bigM ~*~ (constant 1 ~-~ var b), + var z ~>=~ var y ~-~ bigM ~*~ var b + ] + oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = mkC b1 x1 diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 63e2320a47..ad11df4729 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -52,12 +52,12 @@ toType = first (const SDim) data Ct = CtEq Type Type - | CtAM SVar SVar + | CtAM SVar SVar (Shape SComp) deriving (Show) instance Pretty Ct where pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtAM r m) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" + pretty (CtAM r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" type Constraints = [Ct] @@ -190,7 +190,7 @@ solveCt :: Ct -> SolveM () solveCt ct = case ct of CtEq t1 t2 -> solveCt' (t1, t2) - CtAM _ _ -> pure () -- Good vibes only. + CtAM _ _ _ -> pure () -- Good vibes only. where bad = throwError $ "Unsolvable: " <> prettyText ct solveCt' (t1, t2) = do diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 7fffe64107..c890ddaca5 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -73,7 +73,8 @@ instance Distribute Ct where data RankState = RankState { rankBinVars :: Map VName VName, rankCounter :: !Int, - rankConstraints :: [Constraint] + rankConstraints :: [Constraint], + rankObj :: LSum } newtype RankM a = RankM {runRankM :: State RankState a} @@ -106,12 +107,22 @@ addConstraints cs = addConstraint :: Constraint -> RankM () addConstraint = addConstraints . pure +addObj :: SVar -> RankM () +addObj sv = + modify $ \s -> s {rankObj = rankObj s ~+~ var sv} + addCt :: Ct -> RankM () addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 -addCt (CtAM r m) = do +addCt (CtAM r m f) = do b_r <- binVar r b_m <- binVar m + b_max <- VName "b_max" <$> incCounter + tr <- VName ("T_" <> baseName r) <$> incCounter + addConstraints $ [bin b_max, var b_max ~<=~ var tr] addConstraints $ oneIsZero (b_r, r) (b_m, m) + addConstraints $ LP.max b_max (constant 0) (rank r ~-~ rank f) (var tr) + addObj m + addObj tr addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () addTyVarInfo _ (_, TyVarFree) = pure () @@ -126,9 +137,9 @@ mkLinearProg :: [Ct] -> TyVars -> LinearProg mkLinearProg cs tyVars = LP.LinearProg { optType = Minimize, - objective = - let shape_vars = M.keys $ rankBinVars finalState - in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, + objective = rankObj finalState, + -- let shape_vars = M.keys $ rankBinVars finalState + -- in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, constraints = rankConstraints finalState } where @@ -136,7 +147,8 @@ mkLinearProg cs tyVars = RankState { rankBinVars = mempty, rankCounter = 0, - rankConstraints = mempty + rankConstraints = mempty, + rankObj = constant 0 } buildLP = do mapM_ addCt cs diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2107f56fd0..8ad8b2a366 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1620,7 +1620,13 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case (maybe_tysubstss, bodys') of ([], _) -> error "impossible" ([maybe_tysubsts], [body']) -> doChecks (maybe_tysubsts, params', retdecl', body') - _ -> typeError loc mempty "Rank ILP is ambiguous" + (substs, bodies') -> + typeError loc mempty $ + stack $ + [ "Rank ILP is ambiguous.", + "Choices:" + ] + ++ map pretty bodies' where -- TODO: Print out the possibilities. (And also potentially eliminate --- some of the possibilities to disambiguate). diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 59182007b5..5bf91a47fe 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -283,8 +283,8 @@ ctEq t1 t2 = t1' = t1 `setUniqueness` NoUniqueness t2' = t2 `setUniqueness` NoUniqueness -ctAM :: SVar -> SVar -> TermM () -ctAM r m = addCt $ CtAM r m +ctAM :: SVar -> SVar -> Shape SComp -> TermM () +ctAM r m f = addCt $ CtAM r m f localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -637,7 +637,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m)) a - ctAM r m + ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs debugTraceM $ unlines $ From 93fa76e46bbaa2798a4ad9aa6de8d4bd438618a2 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 19:04:38 -0800 Subject: [PATCH 117/296] Add some tests. --- tests/automap/ambiguous0.fut | 1 + tests/automap/bool1.fut | 6 +++ tests/automap/equality1.fut | 23 ++++++++++ tests/automap/lambda.fut | 6 +++ tests/automap/map0.fut | 8 ++++ tests/automap/mri-q-qr.fut | 2 + tests/automap/mri-q.fut | 40 +++++++++++++++++ tests/automap/operator1.fut | 9 ++++ tests/automap/optionpricing.fut | 78 +++++++++++++++++++++++++++++++++ tests/automap/pagerank.fut | 18 ++++++++ tests/automap/project.fut | 9 ++++ tests/automap/projsec1.fut | 9 ++++ tests/automap/same_typevar.fut | 16 +++++++ tests/automap/sgemm.fut | 32 ++++++++++++++ tests/automap/simple1.fut | 7 +++ tests/automap/simple2.fut | 8 ++++ tests/automap/simple3.fut | 8 ++++ tests/automap/simple4.fut | 8 ++++ tests/automap/simple5.fut | 6 +++ 19 files changed, 294 insertions(+) create mode 100644 tests/automap/ambiguous0.fut create mode 100644 tests/automap/bool1.fut create mode 100644 tests/automap/equality1.fut create mode 100644 tests/automap/lambda.fut create mode 100644 tests/automap/map0.fut create mode 100644 tests/automap/mri-q-qr.fut create mode 100644 tests/automap/mri-q.fut create mode 100644 tests/automap/operator1.fut create mode 100644 tests/automap/optionpricing.fut create mode 100644 tests/automap/pagerank.fut create mode 100644 tests/automap/project.fut create mode 100644 tests/automap/projsec1.fut create mode 100644 tests/automap/same_typevar.fut create mode 100644 tests/automap/sgemm.fut create mode 100644 tests/automap/simple1.fut create mode 100644 tests/automap/simple2.fut create mode 100644 tests/automap/simple3.fut create mode 100644 tests/automap/simple4.fut create mode 100644 tests/automap/simple5.fut diff --git a/tests/automap/ambiguous0.fut b/tests/automap/ambiguous0.fut new file mode 100644 index 0000000000..58a663bf36 --- /dev/null +++ b/tests/automap/ambiguous0.fut @@ -0,0 +1 @@ +def ambig (xss : [][]i32) = i64.sum (length xss) diff --git a/tests/automap/bool1.fut b/tests/automap/bool1.fut new file mode 100644 index 0000000000..f3fe08213e --- /dev/null +++ b/tests/automap/bool1.fut @@ -0,0 +1,6 @@ +-- == +-- entry: f +-- input { [true, true, false] [false, true, true] } +-- output { [true, true, true] } + +def f [m] (xs: [m]bool) (ys: [m]bool) = xs || ys diff --git a/tests/automap/equality1.fut b/tests/automap/equality1.fut new file mode 100644 index 0000000000..1604c49d93 --- /dev/null +++ b/tests/automap/equality1.fut @@ -0,0 +1,23 @@ +-- == +-- entry: bigger_to_smaller +-- input { [[1,2],[3,4]] [1,2] } +-- output { [true, false] } + +-- == +-- entry: smaller_to_bigger +-- input { [[1,2],[3,4]] [1,2] } +-- output { [true, false] } + +-- == +-- entry: smaller_to_bigger2 +-- input { [[1,2],[3,4]] 1 } +-- output { [[true,false],[false,false]]} + +def bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = + xss == ys + +def smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = + ys == xss + +def smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = + z == xss diff --git a/tests/automap/lambda.fut b/tests/automap/lambda.fut new file mode 100644 index 0000000000..1bb7ed26e3 --- /dev/null +++ b/tests/automap/lambda.fut @@ -0,0 +1,6 @@ +-- == +-- entry: main +-- random input { [10]f32 [10]f32 } + +entry main [n](xs: [n]f32) (ys: [n]f32): [n]f32 = + map2 (*) xs ys diff --git a/tests/automap/map0.fut b/tests/automap/map0.fut new file mode 100644 index 0000000000..a5ab0887ae --- /dev/null +++ b/tests/automap/map0.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [0,1,2,3] } +-- output { [1,2,3,4] } + +def automap 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = f as + +entry main (x: []i32) = automap (+1) x diff --git a/tests/automap/mri-q-qr.fut b/tests/automap/mri-q-qr.fut new file mode 100644 index 0000000000..8004f7da5d --- /dev/null +++ b/tests/automap/mri-q-qr.fut @@ -0,0 +1,2 @@ +def qr [numX][numK] (expArgs : [numX][numK]f32) (phiMag : [numK]f32) : [numX]f32 = + f32.sum (f32.cos expArgs * phiMag) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut new file mode 100644 index 0000000000..eaed14333a --- /dev/null +++ b/tests/automap/mri-q.fut @@ -0,0 +1,40 @@ +-- == +-- entry: main +-- random input { [12]f32 [12]f32 [12]f32 [10]f32 [10]f32 [10]f32 [12]f32 [12]f32 } +-- output { true } + +def main_orig [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numX]f32, [numX]f32) = + let phiMag = map2 (\r i -> r*r + i*i) phiR phiI + let expArgs = map3 (\x_e y_e z_e -> + map (2.0f32*f32.pi*) + (map3 (\kx_e ky_e kz_e -> + kx_e * x_e + ky_e * y_e + kz_e * z_e) + kx ky kz)) + x y z + let qr = map1 (map f32.cos >-> map2 (*) phiMag >-> f32.sum) expArgs + let qi = map1 (map f32.sin >-> map2 (*) phiMag >-> f32.sum) expArgs + in (qr, qi) + +def main_am [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numK]f32, [numX][numK]f32) = + let (phiMag : [numK]f32) = phiR * phiR + phiI * phiI + let (expArgs : [numX][numK]f32) = map3 (\(x_e : f32) (y_e : f32) (z_e : f32) -> + 2.0*f32.pi*(kx*x_e + ky*y_e + kz*z_e)) + x y z + in (phiMag, expArgs) + --let (qr : [numX]f32) = f32.sum (f32.cos expArgs * phiMag) -- [numx]f32 + --let (qi : [numX]f32) = f32.sum (f32.sin expArgs * phiMag) -- let (qi_10408: artificial₁₁₄_10524 ~ [M113_10523]f32) + --in (qr, qi) + +--entry main [numK][numX] +-- (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) +-- (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) +-- (phiR: [numK]f32) (phiI: [numK]f32) = +-- main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI diff --git a/tests/automap/operator1.fut b/tests/automap/operator1.fut new file mode 100644 index 0000000000..464a8b79c4 --- /dev/null +++ b/tests/automap/operator1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [10,20] } +-- output { [[11, 22],[13, 24]] } + +def (+^) [n] (xs: [n]i32) (ys: [n]i32) : [n]i32 = xs + ys + +--entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = +-- xss +^ ys diff --git a/tests/automap/optionpricing.fut b/tests/automap/optionpricing.fut new file mode 100644 index 0000000000..c58bc39a0a --- /dev/null +++ b/tests/automap/optionpricing.fut @@ -0,0 +1,78 @@ +-- == +-- entry: sobolIndR +-- random input { [12][10]i32 i32 } +-- output { true } + +-- == +-- entry: sobolRecI +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +-- == +-- entry: sobolReci2 +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +def grayCode(x: i32): i32 = (x >> 1) ^ x + +def testBit(n: i32, ind: i32): bool = + let t = (1 << ind) in (n & t) == t + +def xorInds [num_bits] (n: i32) (dir_vs: [num_bits]i32): i32 = + let reldv_vals = map (\(dv: i32, i): i32 -> + if testBit(grayCode(n),i32.i64 i) + then dv else 0 + ) (zip (dir_vs) (iota(num_bits)) ) in + reduce (^) 0 (reldv_vals ) + + +def sobolIndI [len] (dir_vs: [len][]i32, n: i32 ): [len]i32 = + map (xorInds(n)) (dir_vs ) + +def index_of_least_significant_0(num_bits: i32, n: i32): i32 = + let (goon,k) = (true,0) in + let (_,k,_) = loop ((goon,k,n)) for i < num_bits do + if(goon) + then if (n & 1) == 1 + then (true, k+1, n>>1) + else (false,k, n ) + else (false,k, n ) + in k + +def recM [len][num_bits] (sob_dirs: [len][num_bits]i32, i: i32 ): [len]i32 = + let bit= index_of_least_significant_0(i32.i64 num_bits,i) in + map (\(row: []i32): i32 -> row[bit]) (sob_dirs ) + +def sobolIndR_orig [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = map (xorInds n) dir_vs + in map (\x -> f32.i32(x) / divisor) arri + +def sobolRecI_orig [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in map2 (\vct_row prev -> vct_row[bit] ^ prev) sob_dir_vs prev + +def sobolReci2_orig [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + let col = recM(sob_dirs, i) + in map2 (^) prev col + +def sobolIndR_am [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = xorInds n dir_vs + in f32.i32 arri / divisor + +def sobolRecI_am [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in sob_dir_vs[:,bit] ^ prev + +def sobolReci2_am [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + prev ^ recM(sob_dirs, i) + +entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): []bool = + sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n + +entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): []bool = + sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x) + +entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): []bool = + sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i) diff --git a/tests/automap/pagerank.fut b/tests/automap/pagerank.fut new file mode 100644 index 0000000000..c444932de5 --- /dev/null +++ b/tests/automap/pagerank.fut @@ -0,0 +1,18 @@ +-- == +-- entry: calculate_dangling_ranks +-- random input { [12]f32 [12]i32} +-- output { true } + +def calculate_dangling_ranks_orig [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let zipped = zip sizes ranks + let weights = map (\(size, rank) -> if size == 0 then rank else 0f32) zipped + let total = f32.sum weights / f32.i64 n + in map (+total) ranks + +def calculate_dangling_ranks_am [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let weights = f32.bool (sizes == 0) * ranks + let total = f32.sum weights / f32.i64 n + in ranks + total + +entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): []bool = + calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes diff --git a/tests/automap/project.fut b/tests/automap/project.fut new file mode 100644 index 0000000000..2902d0565a --- /dev/null +++ b/tests/automap/project.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in xsys.0 ++ xsys.1 + diff --git a/tests/automap/projsec1.fut b/tests/automap/projsec1.fut new file mode 100644 index 0000000000..485c977bc5 --- /dev/null +++ b/tests/automap/projsec1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in (.0) xsys ++ (.1) xsys + diff --git a/tests/automap/same_typevar.fut b/tests/automap/same_typevar.fut new file mode 100644 index 0000000000..260a00b785 --- /dev/null +++ b/tests/automap/same_typevar.fut @@ -0,0 +1,16 @@ +-- == +-- tags { no_wasm } +-- entry: big_to_small +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +-- == +-- entry: small_to_big +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +def f 'a (x: a) (y: a) (z: a) = (x, y, z) + +entry big_to_small [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f xss ys z + +entry small_to_big [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f z ys xss diff --git a/tests/automap/sgemm.fut b/tests/automap/sgemm.fut new file mode 100644 index 0000000000..56dc08eb7e --- /dev/null +++ b/tests/automap/sgemm.fut @@ -0,0 +1,32 @@ +-- == +-- entry: main +-- random input { [5][10]f32 [10][3]f32 [5][3]f32 f32 f32 } +-- output { true } + +def mult_orig [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + let dotprod xs ys = f32.sum (map2 (*) xs ys) + in map (\xs -> map (dotprod xs) (transpose yss)) xss + +def add [n][m] (xss: [n][m]f32, yss: [n][m]f32): [n][m]f32 = + map2 (map2 (+)) xss yss + +def scale [n][m] (xss: [n][m]f32, a: f32): [n][m]f32 = + map (map1 (*a)) xss + +def main_orig [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + add(scale(css,beta), scale(mult_orig(ass,bss), alpha)) + + +def mult_am [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + f32.sum ((transpose (replicate p xss)) * (replicate n (transpose yss))) + +def main_am [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + css*beta + mult_am(ass,bss)*alpha + +entry main [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) = + main_orig ass bss css alpha beta == main_am ass bss css alpha beta diff --git a/tests/automap/simple1.fut b/tests/automap/simple1.fut new file mode 100644 index 0000000000..f8833bb3b6 --- /dev/null +++ b/tests/automap/simple1.fut @@ -0,0 +1,7 @@ +-- == +-- entry: main +-- input { [1,2] 10 } +-- output { [11, 12] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + xs + y diff --git a/tests/automap/simple2.fut b/tests/automap/simple2.fut new file mode 100644 index 0000000000..ac57abcbe0 --- /dev/null +++ b/tests/automap/simple2.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + xss + ys + diff --git a/tests/automap/simple3.fut b/tests/automap/simple3.fut new file mode 100644 index 0000000000..adc60bd43f --- /dev/null +++ b/tests/automap/simple3.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + ys + xss + diff --git a/tests/automap/simple4.fut b/tests/automap/simple4.fut new file mode 100644 index 0000000000..d94bbe4a6b --- /dev/null +++ b/tests/automap/simple4.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { 3 [1,1] [[1,2],[3,4]] } +-- output { [[5,6],[7,8]] } + +entry main [n] (x : i32) (ys: [n]i32) (zss : [n][n]i32) : [n][n]i32 = + x + ys + zss + diff --git a/tests/automap/simple5.fut b/tests/automap/simple5.fut new file mode 100644 index 0000000000..46610e6567 --- /dev/null +++ b/tests/automap/simple5.fut @@ -0,0 +1,6 @@ +-- == +-- input { [1,2,3] 4 } +-- output { [5, 6, 7] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + (\x y -> x + y) xs y From b3d3c423e03222c8a488b77f366e7c5a090ad87a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 19:40:25 -0800 Subject: [PATCH 118/296] Renaming hack to not count these when looking for new solutions. --- src/Language/Futhark/TypeChecker/Rank.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c890ddaca5..186f8d0fb3 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -116,7 +116,7 @@ addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 addCt (CtAM r m f) = do b_r <- binVar r b_m <- binVar m - b_max <- VName "b_max" <$> incCounter + b_max <- VName "c_max" <$> incCounter tr <- VName ("T_" <> baseName r) <$> incCounter addConstraints $ [bin b_max, var b_max ~<=~ var tr] addConstraints $ oneIsZero (b_r, r) (b_m, m) From 3cdce28d1a63497ccfba1ea82dd983593ebc603f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 23 Feb 2024 22:23:50 -0800 Subject: [PATCH 119/296] Add leetcode test. --- tests/automap/leetcode.fut | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tests/automap/leetcode.fut diff --git a/tests/automap/leetcode.fut b/tests/automap/leetcode.fut new file mode 100644 index 0000000000..43a50cb2b8 --- /dev/null +++ b/tests/automap/leetcode.fut @@ -0,0 +1,4 @@ +def outerprod f x y = map (f >-> flip map y) x +def bidd A = outerprod (==) (indices A) (indices A) +def xmat A = bidd A || reverse (bidd A) +def check_matrix (A : [][]i32) = xmat A == (A != 0) |> flatten |> and From 6616a07928260326e1029e8759552dfce4f54655 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 24 Feb 2024 21:17:24 +0100 Subject: [PATCH 120/296] Easier to read with some linebreaks. --- src/Language/Futhark/Pretty.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 8519b22060..3318fe11e4 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -239,7 +239,7 @@ prettyAppExp p (BinOp (bop, _) _ (x, xi) (y, yi) _) = (Just (_, xam), Just (_, yam)) | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> -- fix - parens (prettyBinOp p bop x y <+> "Δ" <+> pretty xam <+> "Δ" <+> pretty yam) + parens $ align $ prettyBinOp p bop x y "Δ" <+> pretty xam "Δ" <+> pretty yam _ -> prettyBinOp p bop x y prettyAppExp _ (Match e cs _) = "match" <+> pretty e (stack . map pretty) (NE.toList cs) prettyAppExp _ (Loop sizeparams pat initexp form loopbody _) = From 2de3f1eb7537530bc123897aa17c3eb154936d34 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 25 Feb 2024 10:49:45 +0100 Subject: [PATCH 121/296] Remove map intrinsic. AUTOMAP all the way! --- prelude/soacs.fut | 2 +- prelude/zip.fut | 18 ++++++------------ src/Futhark/Internalise/Exps.hs | 9 +-------- src/Language/Futhark/Interpreter.hs | 16 ---------------- src/Language/Futhark/Prop.hs | 10 ---------- 5 files changed, 8 insertions(+), 47 deletions(-) diff --git a/prelude/soacs.fut b/prelude/soacs.fut index 310fad5421..9cda4d2e69 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -48,7 +48,7 @@ import "zip" -- -- **Span:** *O(S(f))* def map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map f as + f as -- | Apply the given function to each element of a single array. -- diff --git a/prelude/zip.fut b/prelude/zip.fut index 1171820307..48816fe97a 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -7,12 +7,6 @@ -- The main reason this module exists is that we need it to define -- SOACs like `map2`. --- We need a map to define some of the zip variants, but this file is --- depended upon by soacs.fut. So we just define a quick-and-dirty --- internal one here that uses the intrinsic version. -local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map f as - -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = intrinsics.zip as bs @@ -23,15 +17,15 @@ def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = -- | As `zip2`@term, but with one more array. def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c): *[n](a,b,c) = - internal_map (\(a,(b,c)) -> (a,b,c)) (zip as (zip2 bs cs)) + (\(a,(b,c)) -> (a,b,c)) (zip as (zip2 bs cs)) -- | As `zip3`@term, but with one more array. def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d): *[n](a,b,c,d) = - internal_map (\(a,(b,c,d)) -> (a,b,c,d)) (zip as (zip3 bs cs ds)) + (\(a,(b,c,d)) -> (a,b,c,d)) (zip as (zip3 bs cs ds)) -- | As `zip4`@term, but with one more array. def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e): *[n](a,b,c,d,e) = - internal_map (\(a,(b,c,d,e)) -> (a,b,c,d,e)) (zip as (zip4 bs cs ds es)) + (\(a,(b,c,d,e)) -> (a,b,c,d,e)) (zip as (zip4 bs cs ds es)) -- | Turn an array of pairs into two arrays. def unzip [n] 'a 'b (xs: [n](a,b)): ([n]a, [n]b) = @@ -43,18 +37,18 @@ def unzip2 [n] 'a 'b (xs: [n](a,b)): ([n]a, [n]b) = -- | As `unzip2`@term, but with one more array. def unzip3 [n] 'a 'b 'c (xs: [n](a,b,c)): ([n]a, [n]b, [n]c) = - let (as, bcs) = unzip (internal_map (\(a,b,c) -> (a,(b,c))) xs) + let (as, bcs) = unzip ((\(a,b,c) -> (a,(b,c))) xs) let (bs, cs) = unzip bcs in (as, bs, cs) -- | As `unzip3`@term, but with one more array. def unzip4 [n] 'a 'b 'c 'd (xs: [n](a,b,c,d)): ([n]a, [n]b, [n]c, [n]d) = - let (as, bs, cds) = unzip3 (internal_map (\(a,b,c,d) -> (a,b,(c,d))) xs) + let (as, bs, cds) = unzip3 ((\(a,b,c,d) -> (a,b,(c,d))) xs) let (cs, ds) = unzip cds in (as, bs, cs, ds) -- | As `unzip4`@term, but with one more array. def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a,b,c,d,e)): ([n]a, [n]b, [n]c, [n]d, [n]e) = - let (as, bs, cs, des) = unzip4 (internal_map (\(a,b,c,d,e) -> (a,b,c,(d,e))) xs) + let (as, bs, cs, des) = unzip4 ((\(a,b,c,d,e) -> (a,b,c,(d,e))) xs) let (ds, es) = unzip des in (as, bs, cs, ds, es) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 5f933e659c..fab4677f2f 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1024,7 +1024,6 @@ withAutoMap ams arg_desc res_t args_e innerM = do where stripAutoMapDims i am = am {autoMap = E.Shape $ drop i $ E.shapeDims $ autoMap am} - autoMapRank = E.shapeRank . autoMap max_am = maximumBy (\x y -> E.shapeRank x `compare` E.shapeRank y) $ fmap autoMap ams inner_t = E.stripArray (E.shapeRank max_am) res_t ds = map autoMapRank ams @@ -1701,7 +1700,7 @@ findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName, AutoMap)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) - | E.Hole (Info t) loc <- f = + | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where onArg (Info (argext, am), e) = (e, argext, am) @@ -1859,12 +1858,6 @@ isIntrinsicFunction qname args loc = do fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing - handleSOACs [lam, arr] "map" = Just $ \desc -> do - arr' <- internaliseExpToVars "map_arr" arr - arr_ts <- mapM lookupType arr' - lam' <- internaliseLambdaCoerce lam $ map rowType arr_ts - let w = arraysSize 0 arr_ts - letTupExp' desc $ I.Op $ I.Screma w arr' (I.mapSOAC lam') handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 527c734cbf..37358dcda1 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -1577,22 +1577,6 @@ initialCtx = Just $ fun2 stream def s | "reduce_stream" `isPrefixOf` s = Just $ fun3 $ \_ f arg -> stream f arg - def "map" = Just $ - TermPoly Nothing $ \t eval' -> do - t' <- evalType eval' mempty t - pure $ ValueFun $ \f -> pure . ValueFun $ \xs -> - case unfoldFunType t' of - ([_, _], ret_t) - | Just rowshape <- typeRowShape ret_t -> - toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) - | otherwise -> - error $ "Bad return type: " <> prettyString ret_t - _ -> - error $ - "Invalid arguments to map intrinsic:\n" - ++ unlines [prettyString t, show f, show xs] - where - typeRowShape = sequenceA . structTypeShape . stripArray 1 def s | "reduce" `isPrefixOf` s = Just $ fun3 $ \f ne xs -> foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 6c69e1ef3d..280b531286 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -840,16 +840,6 @@ intrinsics = $ array_a Unique $ shape [m, k, l] ), - ( "map", - IntrinsicPolyFun - [tp_a, tp_b, sp_n] - [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), - array_a Observe $ shape [n] - ] - $ RetType [] - $ array_b Unique - $ shape [n] - ), ( "reduce", IntrinsicPolyFun [tp_a, sp_n] From d24831df9f64058670f2a648003694537a6b79a8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 25 Feb 2024 07:13:27 -0800 Subject: [PATCH 122/296] Fix AUTOMAP shapes for the replicate case. --- src/Language/Futhark/TypeChecker/Terms.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 8ad8b2a366..a661507772 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -946,7 +946,10 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d onFailure (CheckingApply fname argexp tp1 argtype) $ do (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 - let (am_frame_shape, argtype_automap) = splitArrayAt (autoFrameRank am) argtype_with_frame + (am_frame_shape, argtype_automap) <- + if autoMapRank am == 0 + then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 + else pure $ splitArrayAt (autoFrameRank am) argtype_with_frame debugTraceM $ unlines @@ -1005,7 +1008,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d let am = AutoMap - { autoRep = mempty, + { autoRep = am_rep_shape, autoMap = am_map_shape, autoFrame = am_map_shape <> am_frame_shape } From 60db26b21d1a82c7ab92b3ebd3b92a35cb39c7bb Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 25 Feb 2024 07:17:34 -0800 Subject: [PATCH 123/296] Remove confusing/wrong name. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index a661507772..67caf7a3c7 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -946,7 +946,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d onFailure (CheckingApply fname argexp tp1 argtype) $ do (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 - (am_frame_shape, argtype_automap) <- + (am_frame_shape, _) <- if autoMapRank am == 0 then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 else pure $ splitArrayAt (autoFrameRank am) argtype_with_frame From 54882ab6b20fbfaf679c57e615fe39044e7610f9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 25 Feb 2024 07:27:02 -0800 Subject: [PATCH 124/296] Better frame computation + clarifying notes. --- src/Language/Futhark/TypeChecker/Terms.hs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 67caf7a3c7..e79fa461a0 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -944,12 +944,24 @@ checkApply :: checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = do let argtype = typeOf argexp onFailure (CheckingApply fname argexp tp1 argtype) $ do + -- argtype = arg_frame argtype' + -- tp1 = f_frame tp1' + -- + -- Rep case: + -- R arg_frame argtype' = f_frame tp1' + -- ==> R = (autoRepRank am)-length prefix of tp1 + -- ==> frame = f_frame = (autoFrameRank am)-length prefix of tp1 + -- + -- Map case: + -- arg_frame argtype' = M f_frame tp1' + -- ==> M = (autoMapRank am)-length prefix of argtype + -- ==> frame = M f_frame = (autoFrameRank am)-length prefix of argtype (am_map_shape, argtype_with_frame) <- splitArrayAt (autoMapRank am) <$> normTypeFully argtype (am_rep_shape, tp1_with_frame) <- splitArrayAt (autoRepRank am) <$> normTypeFully tp1 (am_frame_shape, _) <- if autoMapRank am == 0 then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 - else pure $ splitArrayAt (autoFrameRank am) argtype_with_frame + else splitArrayAt (autoFrameRank am) <$> normTypeFully argtype debugTraceM $ unlines @@ -1010,7 +1022,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d AutoMap { autoRep = am_rep_shape, autoMap = am_map_shape, - autoFrame = am_map_shape <> am_frame_shape + autoFrame = am_frame_shape } pure (tp1, distributeFrame (autoMap am) tp2'', argext, ext, am) From c04ade2456e19ae97be412612bbc2399db5034d9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 25 Feb 2024 17:54:35 +0100 Subject: [PATCH 125/296] Bump Nix and cabal. --- cabal.project | 2 +- nix/sources.json | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cabal.project b/cabal.project index f5a57ba11c..d8f795a28e 100644 --- a/cabal.project +++ b/cabal.project @@ -1,5 +1,5 @@ packages: futhark.cabal -index-state: 2024-01-24T22:19:37Z +index-state: 2024-02-25T13:57:21Z package futhark ghc-options: -j -fwrite-ide-info -hiedir=.hie diff --git a/nix/sources.json b/nix/sources.json index 1a95fd6f85..a60e35d52a 100644 --- a/nix/sources.json +++ b/nix/sources.json @@ -17,10 +17,10 @@ "homepage": "", "owner": "NixOS", "repo": "nixpkgs", - "rev": "2bcbada7a108ef5584abda1e36c42109d1f0d374", - "sha256": "12n79sl0nkp3b25ifdz9i8d9046g6dqz8g2jghg8d3836yjih7qj", + "rev": "efeff60fd4a0bc4f639a217a723f9e11df3f5e20", + "sha256": "09gxq604v7r9sl5qgp37n6414z2jivdjipwyrhka0d4rdhdbm31m", "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/2bcbada7a108ef5584abda1e36c42109d1f0d374.tar.gz", + "url": "https://github.com/NixOS/nixpkgs/archive/efeff60fd4a0bc4f639a217a723f9e11df3f5e20.tar.gz", "url_template": "https://github.com///archive/.tar.gz" } } From 6c91e26e392c6cadaa338d8ae37f8937a4ab486b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 25 Feb 2024 23:21:34 +0100 Subject: [PATCH 126/296] Work on supporting AUTOMAP in interpreter. --- src/Language/Futhark/Interpreter.hs | 117 +++++++++++++++++---- src/Language/Futhark/Interpreter/Values.hs | 20 +++- 2 files changed, 116 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 37358dcda1..39c278c4a7 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -414,6 +414,11 @@ fromArray :: Value -> (ValueShape, [Value]) fromArray (ValueArray shape as) = (shape, elems as) fromArray v = error $ "Expected array value, but found: " <> show v +fromArrayR :: Int -> Value -> [Value] +fromArrayR 0 v = [v] +fromArrayR 1 v = snd $ fromArray v +fromArrayR n v = concatMap (fromArrayR (n - 1)) $ snd $ fromArray v + apply :: SrcLoc -> Env -> Value -> Value -> EvalM Value apply loc env (ValueFun f) v = stacking loc env (f v) apply _ _ f _ = error $ "Cannot apply non-function: " <> show f @@ -423,6 +428,54 @@ apply2 loc env f x y = stacking loc env $ do f' <- apply noLoc mempty f x apply noLoc mempty f' y +data AutoMapArg + = -- | Map function across argument of this shape. + AutoMapMap [Int64] + | -- | Replicate argument to array of this shape. + AutoMapRep [Int64] + | AutoMapNone + deriving (Eq, Ord, Show) + +applyAM :: + SrcLoc -> + Env -> + (Value, StructType) -> + AutoMapArg -> + Value -> + EvalM Value +applyAM loc env (ValueArray _ xs, ft) AutoMapNone v = do + t' <- evalType (eval env) mempty ft + undefined +applyAM loc env (f, _) AutoMapNone v = + apply loc env f v +applyAM loc env (f, _) (AutoMapMap []) v = + apply loc env f v +applyAM loc env (f, _) (AutoMapRep []) v = + apply loc env f v +applyAM loc env (f, _) (AutoMapRep shape) v = + apply noLoc mempty f $ repArray shape v +-- The next case essentially implements the "map" primitive. +applyAM loc env (f, ft) (AutoMapMap shape) v = do + t' <- evalType (eval env) mempty ft + let rank = length shape + vs = fromArrayR rank v + case t' of + Scalar (Arrow _ _ _ _ (RetType _ ret_t@(Scalar Arrow {}))) + | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> do + fs <- mapM (apply noLoc mempty f) vs + pure $ ValueFun $ \v' -> + toArrayR shape rowshape + <$> zipWithM (apply loc env) fs (fromArrayR rank v') + Scalar (Arrow _ _ _ _ (RetType _ ret_t)) + | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> + toArrayR shape rowshape <$> mapM (apply noLoc mempty f) vs + | otherwise -> + error $ "Bad return type: " <> prettyString ret_t + _ -> + error $ + "Invalid automap arguments:\n" + ++ unlines [prettyString ft, show f, show v] + matchPat :: Env -> Pat (TypeBase Size u) -> Value -> EvalM Env matchPat env p v = do m <- runMaybeT $ patternMatch env p v @@ -752,13 +805,21 @@ evalFunctionBinding env tparams ps ret fbody = do returned env (retType ret) retext =<< evalFunction env' missing_sizes ps fbody (retType ret) -evalArg :: Env -> Exp -> Maybe VName -> EvalM Value -evalArg env e ext = do +evalArg :: Env -> Exp -> Maybe VName -> AutoMap -> EvalM (Value, AutoMapArg) +evalArg env e ext am = do v <- eval env e case ext of Just ext' -> putExtSize ext' v _ -> pure () - pure v + let evalShape = mapM (fmap asInt64 . eval env) . shapeDims + am' <- + if not $ null $ autoMap am + then AutoMapMap <$> evalShape (autoMap am) + else + if not $ null $ autoRep am + then AutoMapRep <$> evalShape (autoRep am) + else pure AutoMapNone + pure (v, am') returned :: Env -> TypeBase Size als -> [VName] -> Value -> EvalM Value returned _ _ [] v = pure v @@ -828,22 +889,31 @@ evalAppExp env (LetPat sizes p e body _) = do evalAppExp env (LetFun f (tparams, ps, _, Info ret, fbody) body _) = do binding <- evalFunctionBinding env tparams ps ret fbody eval (env {envTerm = M.insert f binding $ envTerm env}) body -evalAppExp env (BinOp (op, _) op_t (x, Info (xext, xam)) (y, Info (yext, yam)) loc) - | baseString (qualLeaf op) == "&&" = do +evalAppExp env (BinOp (op, _) (Info op_t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) + | baseString (qualLeaf op) == "&&", + noAutoMap = do x' <- asBool <$> eval env x if x' then eval env y else pure $ ValuePrim $ BoolValue False - | baseString (qualLeaf op) == "||" = do + | baseString (qualLeaf op) == "||", + noAutoMap = do x' <- asBool <$> eval env x if x' then pure $ ValuePrim $ BoolValue True else eval env y | otherwise = do - x' <- evalArg env x xext - y' <- evalArg env y yext - op' <- eval env $ Var op op_t loc - apply2 loc env op' x' y' + (x', xam') <- evalArg env x xext xam + (y', yam') <- evalArg env y yext yam + op' <- evalTermVar env op op_t + op'' <- applyAM loc env (op', op_t) xam' x' + applyAM loc env (op'', op_ret) yam' y' + where + op_ret = case op_t of + Scalar (Arrow _ _ _ _ (RetType _ t)) -> + toStruct t + _ -> error $ "Nonsensical binop type: " <> prettyString op_t + noAutoMap = xam == mempty && yam == mempty evalAppExp env (If cond e1 e2 _) = do cond' <- asBool <$> eval env cond if cond' then eval env e1 else eval env e2 @@ -853,9 +923,11 @@ evalAppExp env (Apply f args loc) = do -- type of the functions. args' <- reverse <$> mapM evalArg' (reverse $ NE.toList args) f' <- eval env f - foldM (apply loc env) f' args' + foldM apply' f' args' where - evalArg' (Info (ext, _), x) = evalArg env x ext + ft = typeOf f + apply' f' (v', am') = applyAM loc env (f', ft) am' v' + evalArg' (Info (ext, am), x) = evalArg env x ext am evalAppExp env (Index e is loc) = do is' <- mapM (evalDimIndex env) is arr <- eval env e @@ -1047,16 +1119,21 @@ eval env (Lambda ps body _ (Info (RetType _ rt)) _) = evalFunction env [] ps body rt eval env (OpSection qv (Info t) _) = evalTermVar env qv $ toStruct t -eval env (OpSectionLeft qv _ e (Info (_, _, argext, _), _) (Info (RetType _ t), _) loc) = do - v <- evalArg env e argext - f <- evalTermVar env qv (toStruct t) - apply loc env f v -eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext, _)) (Info (RetType _ t)) loc) = do - y <- evalArg env e argext +eval env (OpSectionLeft qv _ e (Info (_, _, argext, am), _) (Info (RetType _ t), _) loc) = do + (v, am') <- evalArg env e argext am + f <- evalTermVar env qv t' + applyAM loc env (f, t') am' v + where + t' = toStruct t +eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext, am)) (Info (RetType _ t)) loc) = do + (y, am') <- evalArg env e argext am pure $ ValueFun $ \x -> do - f <- evalTermVar env qv $ toStruct t - apply2 loc env f x y + f <- evalTermVar env qv t' + f' <- apply loc env f x + applyAM loc env (f', t') am' y + where + t' = toStruct t eval env (IndexSection is _ loc) = do is' <- mapM (evalDimIndex env) is pure $ ValueFun $ evalIndex loc env is' diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index 40f0a8b287..b3fb36ac8c 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -20,7 +20,9 @@ module Language.Futhark.Interpreter.Values prettyEmptyArray, toArray, toArray', + toArrayR, toTuple, + repArray, -- * Conversion fromDataValue, @@ -28,7 +30,7 @@ module Language.Futhark.Interpreter.Values where import Data.Array -import Data.List (genericLength) +import Data.List (genericLength, genericReplicate) import Data.Map qualified as M import Data.Maybe import Data.Monoid hiding (Sum) @@ -206,6 +208,15 @@ toArray' rowshape vs = ValueArray shape (listArray (0, length vs - 1) vs) where shape = ShapeDim (genericLength vs) rowshape +-- | Produce multidimensional array from a flat list of values. +toArrayR :: [Int64] -> ValueShape -> [Value m] -> Value m +toArrayR [] _ = error "toArrayR: empty shape" +toArrayR [_] elemshape = toArray' elemshape +toArrayR (n : ns) elemshape = + toArray (foldr ShapeDim elemshape (n : ns)) + . map (toArrayR ns elemshape) + . chunk (fromIntegral (product ns)) + arrayLength :: (Integral int) => Array Int (Value m) -> int arrayLength = fromIntegral . (+ 1) . snd . bounds @@ -237,6 +248,13 @@ fromDataValueWith f shape vector where shape' = SVec.tail shape +repArray :: [Int64] -> Value m -> Value m +repArray [] v = v +repArray (n : ns) v = + toArray' (valueShape v') (genericReplicate n v') + where + v' = repArray ns v + -- | Convert a Futhark value in the externally observable data format -- to an interpreter value. fromDataValue :: V.Value -> Value m From 9919fec8dfb981695b2d7dca080735202e2ed9da Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 15:45:57 +0100 Subject: [PATCH 127/296] We must also touch automaps here. --- src/Language/Futhark/Traversals.hs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index fc20935c24..94b440b2ff 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -60,6 +60,13 @@ class ASTMappable x where -- into subexpressions. The mapping is done left-to-right. astMap :: (Monad m) => ASTMapper m -> x -> m x +mapOnAutoMap :: (Monad m) => ASTMapper m -> AutoMap -> m AutoMap +mapOnAutoMap tv (AutoMap r m f) = + AutoMap + <$> traverse (mapOnExp tv) r + <*> traverse (mapOnExp tv) m + <*> traverse (mapOnExp tv) f + instance ASTMappable (AppExpBase Info VName) where astMap tv (Range start next end loc) = Range @@ -73,7 +80,7 @@ instance ASTMappable (AppExpBase Info VName) where Match <$> mapOnExp tv e <*> astMap tv cases <*> pure loc astMap tv (Apply f args loc) = do f' <- mapOnExp tv f - args' <- traverse (traverse $ mapOnExp tv) args + args' <- traverse onArg args -- Safe to disregard return type because existentials cannot be -- instantiated here, as the return is necessarily a function. pure $ case f' of @@ -81,6 +88,9 @@ instance ASTMappable (AppExpBase Info VName) where Apply f_inner (args_inner <> args') loc _ -> Apply f' args' loc + where + onArg (Info (ext, am), e) = + (,) <$> (Info . (ext,) <$> mapOnAutoMap tv am) <*> mapOnExp tv e astMap tv (LetPat sizes pat e body loc) = LetPat sizes <$> astMap tv pat <*> mapOnExp tv e <*> mapOnExp tv body <*> pure loc astMap tv (LetFun name (tparams, params, ret, t, e) body loc) = @@ -101,13 +111,16 @@ instance ASTMappable (AppExpBase Info VName) where <*> mapOnExp tv vexp <*> mapOnExp tv body <*> pure loc - astMap tv (BinOp (fname, fname_loc) t (x, xext) (y, yext) loc) = + astMap tv (BinOp (fname, fname_loc) t x y loc) = BinOp <$> ((,) <$> mapOnName tv fname <*> pure fname_loc) <*> traverse (mapOnStructType tv) t - <*> ((,) <$> mapOnExp tv x <*> pure xext) - <*> ((,) <$> mapOnExp tv y <*> pure yext) + <*> onArg x + <*> onArg y <*> pure loc + where + onArg (e, Info (ext, am)) = + (,) <$> mapOnExp tv e <*> (Info . (ext,) <$> mapOnAutoMap tv am) astMap tv (Loop sparams mergepat mergeexp form loopbody loc) = Loop sparams <$> astMap tv mergepat From afa09d3e501d266db9bf8b0c213539908ab2283b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 16:31:24 +0100 Subject: [PATCH 128/296] Add another test program. --- tests/automap/combinations.fut | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/automap/combinations.fut diff --git a/tests/automap/combinations.fut b/tests/automap/combinations.fut new file mode 100644 index 0000000000..5c49fb251f --- /dev/null +++ b/tests/automap/combinations.fut @@ -0,0 +1,33 @@ +-- All the various ways one can imagine automapping a very simple program. + +def plus (x: i32) (y: i32) = x + y + +-- == +-- entry: vecint +-- input { [1,2,3] } output { [3,4,5] } + +entry vecint (x: []i32) = plus x 2 + +-- == +-- entry: vecvec +-- input { [1,2,3] } output { [2,4,6] } + +entry vecvec (x: []i32) = plus x x + +-- == +-- entry: matint +-- input { [[1,2],[3,4]] } output { [[3,4],[5,6]] } + +entry matint (x: [][]i32) = plus x 2 + +-- == +-- entry: matmat +-- input { [[1,2],[3,4]] } output { [[2,4],[6,8]] } + +entry matmat (x: [][]i32) = plus x x + +-- == +-- entry: matvec +-- input { [[1,2],[3,4]] [5,6] } output { [[6,8],[8,10]] } + +entry matvec (x: [][]i32) (y: []i32) = plus x y From 58c8ff33e06a8e91078126ae03763c1d7ddd0859 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 08:01:39 -0800 Subject: [PATCH 129/296] Distribute frames recursively. --- src/Language/Futhark/TypeChecker/Terms.hs | 15 +++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 11 ----------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2b8b9f6c16..25c6dcb0c8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1025,11 +1025,18 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d autoFrame = am_frame_shape } - pure (tp1, distributeFrame (autoMap am) tp2'', argext, ext, am) + pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am) where - distributeFrame frame (Scalar (Arrow u p d a (RetType ds b))) = - Scalar $ Arrow u p d (arrayOf frame a) (RetType ds (arrayOfWithAliases (uniqueness b) frame b)) - distributeFrame frame t = arrayOf frame t + distribute :: TypeBase dim u -> TypeBase dim u + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute (arrayOfWithAliases (uniqueness tr) s tr)) + distribute t = t checkApply _ _ _ _ _ = error "checkApply: array" diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 5bf91a47fe..ccba90ab07 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -688,17 +688,6 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) -distribute :: TypeBase dim u -> TypeBase dim u -distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ - Arrow - u - Unnamed - mempty - (arrayOf s ta) - (RetType rd $ distribute (arrayOfWithAliases (uniqueness tr) s tr)) -distribute t = t - checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where From 99fb122442709e4ab11e4d16caab276eb34cb8d9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 20:14:10 +0100 Subject: [PATCH 130/296] The vindication of Robert. --- src/Language/Futhark/Interpreter.hs | 49 +++++++---------------------- tests/automap/combinations.fut | 5 +++ 2 files changed, 17 insertions(+), 37 deletions(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 39c278c4a7..506a8b715d 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -429,11 +429,7 @@ apply2 loc env f x y = stacking loc env $ do apply noLoc mempty f' y data AutoMapArg - = -- | Map function across argument of this shape. - AutoMapMap [Int64] - | -- | Replicate argument to array of this shape. - AutoMapRep [Int64] - | AutoMapNone + = AutoMapArg [Int64] [Int64] [Int64] deriving (Eq, Ord, Show) applyAM :: @@ -443,38 +439,23 @@ applyAM :: AutoMapArg -> Value -> EvalM Value -applyAM loc env (ValueArray _ xs, ft) AutoMapNone v = do - t' <- evalType (eval env) mempty ft - undefined -applyAM loc env (f, _) AutoMapNone v = - apply loc env f v -applyAM loc env (f, _) (AutoMapMap []) v = - apply loc env f v -applyAM loc env (f, _) (AutoMapRep []) v = +applyAM loc env (f, _) (AutoMapArg [] [] []) v = apply loc env f v -applyAM loc env (f, _) (AutoMapRep shape) v = - apply noLoc mempty f $ repArray shape v --- The next case essentially implements the "map" primitive. -applyAM loc env (f, ft) (AutoMapMap shape) v = do +applyAM loc env (f, ft) am@(AutoMapArg repshape mapshape frame) v = do + let v' = repArray repshape v + f' = repArray mapshape f + rank = length frame + vs = fromArrayR rank v' + fs = fromArrayR rank f' t' <- evalType (eval env) mempty ft - let rank = length shape - vs = fromArrayR rank v case t' of - Scalar (Arrow _ _ _ _ (RetType _ ret_t@(Scalar Arrow {}))) - | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> do - fs <- mapM (apply noLoc mempty f) vs - pure $ ValueFun $ \v' -> - toArrayR shape rowshape - <$> zipWithM (apply loc env) fs (fromArrayR rank v') Scalar (Arrow _ _ _ _ (RetType _ ret_t)) | Just rowshape <- sequenceA $ structTypeShape $ toStruct ret_t -> - toArrayR shape rowshape <$> mapM (apply noLoc mempty f) vs - | otherwise -> - error $ "Bad return type: " <> prettyString ret_t + toArrayR frame rowshape <$> zipWithM (apply loc env) fs vs _ -> error $ "Invalid automap arguments:\n" - ++ unlines [prettyString ft, show f, show v] + ++ unlines [prettyString ft, show f, show v, show am] matchPat :: Env -> Pat (TypeBase Size u) -> Value -> EvalM Env matchPat env p v = do @@ -806,19 +787,13 @@ evalFunctionBinding env tparams ps ret fbody = do =<< evalFunction env' missing_sizes ps fbody (retType ret) evalArg :: Env -> Exp -> Maybe VName -> AutoMap -> EvalM (Value, AutoMapArg) -evalArg env e ext am = do +evalArg env e ext (AutoMap rshape mshape frame) = do v <- eval env e case ext of Just ext' -> putExtSize ext' v _ -> pure () let evalShape = mapM (fmap asInt64 . eval env) . shapeDims - am' <- - if not $ null $ autoMap am - then AutoMapMap <$> evalShape (autoMap am) - else - if not $ null $ autoRep am - then AutoMapRep <$> evalShape (autoRep am) - else pure AutoMapNone + am' <- AutoMapArg <$> evalShape rshape <*> evalShape mshape <*> evalShape frame pure (v, am') returned :: Env -> TypeBase Size als -> [VName] -> Value -> EvalM Value diff --git a/tests/automap/combinations.fut b/tests/automap/combinations.fut index 5c49fb251f..7d77e85abb 100644 --- a/tests/automap/combinations.fut +++ b/tests/automap/combinations.fut @@ -31,3 +31,8 @@ entry matmat (x: [][]i32) = plus x x -- input { [[1,2],[3,4]] [5,6] } output { [[6,8],[8,10]] } entry matvec (x: [][]i32) (y: []i32) = plus x y + +-- == +-- entry: vecvecvec +-- input { [1,2,3] } output { [3,6,9] } +entry vecvecvec (x: []i32) = (\x y z -> x + y + z) x x x From e4041695f2ae12c2710f2a2eaca5550fd14eafaf Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 21:13:47 -0800 Subject: [PATCH 131/296] Basic internalization working. --- src/Futhark/Internalise/Exps.hs | 157 +++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index fab4677f2f..3c6b7bb99c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -14,8 +14,10 @@ import Data.List (elemIndex, find, intercalate, intersperse, maximumBy, transpos import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M +import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T +import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings @@ -882,11 +884,16 @@ internalisePatLit l t = withAutoMap_ :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([[SubExp]] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap_ ams arg_desc res_t args_e innerM = - withAutoMap ams arg_desc res_t args_e $ \args_stms -> do + withAutoMapNew (zip3 args_e ams (repeat arg_desc)) $ \args_stms -> do let (args, stms) = unzip args_stms mapM_ addStms $ reverse stms innerM args +-- withAutoMap ams arg_desc res_t args_e $ \args_stms -> do +-- let (args, stms) = unzip args_stms +-- mapM_ addStms $ reverse stms +-- innerM args + -- | Internalization of 'AutoMap'-annotated applications. -- -- Each application @f x@ has an annotation with @AutoMap R M F@ where @@ -1009,6 +1016,154 @@ withAutoMap_ ams arg_desc res_t args_e innerM = -- -- This process continues until the level is greater than the maximum -- true level of any application, at which we terminate. +type Level = Int + +type ArgNum = Int + +type ArgMap = M.Map Level (M.Map ArgNum AutoMapArg) + +data AutoMapArg = AutoMapArg + { amArgs :: [VName], + amArgStms :: Stms SOACS + } + deriving (Show) + +data AutoMapParam = AutoMapParam + { amParams :: [LParam SOACS], + amParamStms :: Stms SOACS, + amMapDim :: SubExp + } + deriving (Show) + +withAutoMapNew :: [((E.Exp, Maybe VName), AutoMap, String)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] +withAutoMapNew args_am func = do + (param_maps, arg_maps) <- + unzip . reverse + <$> mapM buildArgMap (reverse args_am) + let param_map = M.unionsWith (++) $ (fmap . fmap) pure param_maps + arg_map = M.unionsWith (++) $ (fmap . fmap) pure arg_maps + traceM $ + unlines + [ "##param_map", + show param_map, + "##arg_map", + show arg_map + ] + buildMapNest param_map arg_map $ maximum $ M.keys arg_map + where + buildMapNest _ arg_map 0 = + func $ map (\a -> (map I.Var $ amArgs a, amArgStms a)) $ arg_map M.! 0 + buildMapNest param_map arg_map l = + case map amMapDim $ param_map M.! l of + [] -> buildMapNest param_map arg_map (l - 1) + (map_dim : _) -> do + let (params, p_stms) = + unzip $ + map (\p -> (amParams p, amParamStms p)) $ + param_map M.! l + (args, arg_stms) = + unzip $ + map (\a -> (amArgs a, amArgStms a)) $ + arg_map M.! l + letValExp' + "automap" + . Op + . Screma map_dim (concat args) + . mapSOAC + =<< mkLambda + (concat params) + ( do + subExpsRes <$> buildMapNest param_map arg_map (l - 1) + ) + + buildArgMap :: ((E.Exp, Maybe VName), AutoMap, String) -> InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) + buildArgMap (arg, am, arg_desc) = do + ses <- internaliseArg arg_desc arg + arg_vnames <- mapM (letExp "" <=< eSubExp) ses + ts <- mapM subExpType ses + (p_map, a_map) <- + foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ + reverse [0 .. trueLevel am] + traceM $ + unlines + [ "##truelevel am", + show $ trueLevel am, + "## arg", + prettyString arg, + "## am", + show am + ] + + pure (p_map, a_map) + where + mkArgsAndParams arg_vnames ses ts (p_map, a_map) l + | l == 0 = do + let as = + fromMaybe + arg_vnames + ( ( map I.paramName + . amParams + ) + <$> p_map M.!? 1 + ) + (ses, stms) <- mkBottomArgs as ts + pure $ (p_map, M.insert 0 (AutoMapArg ses stms) a_map) + | l == trueLevel am = do + (ps, p_stms) <- mkParams arg_vnames ts l + d <- outerDim am l + pure + ( M.insert l (AutoMapParam ps p_stms d) p_map, + M.insert l (AutoMapArg arg_vnames mempty) a_map + ) + | l < trueLevel am && l > 0 = do + (ps, p_stms) <- mkParams arg_vnames ts l + d <- outerDim am l + let as = + map I.paramName $ + amParams $ + p_map M.! (l + 1) + pure + ( M.insert l (AutoMapParam ps p_stms d) p_map, + M.insert l (AutoMapArg as mempty) a_map + ) + | otherwise = error "" + + mkParams _ ts level = + collectStms $ + forM ts $ \t -> + newParam ("p_" <> arg_desc) $ argType (level - 1) am t + mkBottomArgs arg_vnames ts = + collectStms $ do + rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am + if I.shapeRank rep_shape > 0 + then concat <$> mapM (letValExp "autorep" . BasicOp . Replicate rep_shape . I.Var) arg_vnames + else pure arg_vnames + + argType level am t = I.stripArray (trueLevel am - level) t + + internaliseShape :: E.Shape Size -> InternaliseM I.Shape + internaliseShape = + fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims + + trueLevel :: AutoMap -> Int + trueLevel am + | autoMap am == mempty = max 0 $ E.shapeRank (autoFrame am) - E.shapeRank (autoRep am) + | otherwise = E.shapeRank $ autoFrame am + + outerDim :: AutoMap -> Int -> InternaliseM SubExp + outerDim am level = do + traceM $ + unlines + [ "##outerDim", + "##am", + show am, + "##level", + show level, + "## dff", + show (trueLevel am - level) + ] + internaliseExp1 "" $ (!! (trueLevel am - level)) $ E.shapeDims $ autoFrame am + withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] withAutoMap ams arg_desc res_t args_e innerM = do (args, stms) <- From 27588dc92ac5a4ddbb14a3fb0667377f7d6e9d4f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:05:19 -0800 Subject: [PATCH 132/296] Remove some complexity. --- src/Futhark/Internalise/Exps.hs | 261 +++++++------------------------- 1 file changed, 57 insertions(+), 204 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 3c6b7bb99c..d38e36493c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -17,7 +17,6 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T -import Debug.Trace import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings @@ -354,10 +353,8 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- created by function applications can be brought into scope. let fname = nameFromString $ prettyString $ baseName $ qualLeaf qfname loc = srclocOf e - arg_desc = nameToString fname ++ "_arg" - args = map (\(a, b, _) -> (a, b)) argsam - ams = map (\(_, _, c) -> c) argsam - res_t = et + (args, ams) = unzip argsam + args_am_desc = zip3 args ams (repeat (nameToString fname ++ "_arg")) -- Some functions are magical (overloaded) and we handle that here. case () of @@ -367,20 +364,20 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - withAutoMap_ ams arg_desc res_t args $ \args' -> do - let prepareArg (arg, _, am) arg' = + withAutoMap args_am_desc $ \args' -> do + let prepareArg ((arg, _), am, _) arg' = (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') - internalise $ zipWith prepareArg argsam args' + internalise $ zipWith prepareArg args_am_desc args' | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, Just (rettype, _) <- M.lookup fname I.builtInFunctions -> - withAutoMap_ ams arg_desc res_t args $ \args' -> do + withAutoMap args_am_desc $ \args' -> do let tag ses = [(se, I.Observe) | se <- ses] let args'' = concatMap tag args' letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - withAutoMap_ ams arg_desc res_t args $ \args' -> + withAutoMap args_am_desc $ \args' -> do funcall desc qfname (concat args') loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = internalisePat desc sizes pat e $ internaliseExp desc body @@ -882,18 +879,6 @@ internalisePatLit (E.PatLitFloat x) (E.Scalar (E.Prim (E.FloatType ft))) = internalisePatLit l t = error $ "Nonsensical pattern and type: " ++ show (l, t) -withAutoMap_ :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([[SubExp]] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] -withAutoMap_ ams arg_desc res_t args_e innerM = - withAutoMapNew (zip3 args_e ams (repeat arg_desc)) $ \args_stms -> do - let (args, stms) = unzip args_stms - mapM_ addStms $ reverse stms - innerM args - --- withAutoMap ams arg_desc res_t args_e $ \args_stms -> do --- let (args, stms) = unzip args_stms --- mapM_ addStms $ reverse stms --- innerM args - -- | Internalization of 'AutoMap'-annotated applications. -- -- Each application @f x@ has an annotation with @AutoMap R M F@ where @@ -1018,128 +1003,104 @@ withAutoMap_ ams arg_desc res_t args_e innerM = -- true level of any application, at which we terminate. type Level = Int -type ArgNum = Int - -type ArgMap = M.Map Level (M.Map ArgNum AutoMapArg) - data AutoMapArg = AutoMapArg - { amArgs :: [VName], - amArgStms :: Stms SOACS + { amArgs :: [VName] } deriving (Show) data AutoMapParam = AutoMapParam { amParams :: [LParam SOACS], - amParamStms :: Stms SOACS, amMapDim :: SubExp } deriving (Show) -withAutoMapNew :: [((E.Exp, Maybe VName), AutoMap, String)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] -withAutoMapNew args_am func = do +withAutoMap :: + [((E.Exp, Maybe VName), AutoMap, String)] -> + ([[SubExp]] -> InternaliseM [SubExp]) -> + InternaliseM [SubExp] +withAutoMap args_am func = do (param_maps, arg_maps) <- unzip . reverse <$> mapM buildArgMap (reverse args_am) - let param_map = M.unionsWith (++) $ (fmap . fmap) pure param_maps - arg_map = M.unionsWith (++) $ (fmap . fmap) pure arg_maps - traceM $ - unlines - [ "##param_map", - show param_map, - "##arg_map", - show arg_map - ] + let param_map = M.unionsWith (<>) $ (fmap . fmap) pure param_maps + arg_map = M.unionsWith (<>) $ (fmap . fmap) pure arg_maps buildMapNest param_map arg_map $ maximum $ M.keys arg_map where buildMapNest _ arg_map 0 = - func $ map (\a -> (map I.Var $ amArgs a, amArgStms a)) $ arg_map M.! 0 + func $ map (map I.Var . amArgs) $ arg_map M.! 0 buildMapNest param_map arg_map l = case map amMapDim $ param_map M.! l of [] -> buildMapNest param_map arg_map (l - 1) (map_dim : _) -> do - let (params, p_stms) = - unzip $ - map (\p -> (amParams p, amParamStms p)) $ - param_map M.! l - (args, arg_stms) = - unzip $ - map (\a -> (amArgs a, amArgStms a)) $ - arg_map M.! l + let params = map amParams $ param_map M.! l + args = map amArgs $ arg_map M.! l letValExp' "automap" . Op . Screma map_dim (concat args) . mapSOAC =<< mkLambda - (concat params) - ( do - subExpsRes <$> buildMapNest param_map arg_map (l - 1) - ) + (concat params) + ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) + ) - buildArgMap :: ((E.Exp, Maybe VName), AutoMap, String) -> InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) + buildArgMap :: + ((E.Exp, Maybe VName), AutoMap, String) -> + InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) buildArgMap (arg, am, arg_desc) = do ses <- internaliseArg arg_desc arg arg_vnames <- mapM (letExp "" <=< eSubExp) ses ts <- mapM subExpType ses - (p_map, a_map) <- - foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ - reverse [0 .. trueLevel am] - traceM $ - unlines - [ "##truelevel am", - show $ trueLevel am, - "## arg", - prettyString arg, - "## am", - show am - ] - - pure (p_map, a_map) + foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ + reverse [0 .. trueLevel am] where mkArgsAndParams arg_vnames ses ts (p_map, a_map) l | l == 0 = do let as = - fromMaybe + maybe arg_vnames - ( ( map I.paramName - . amParams - ) - <$> p_map M.!? 1 + ( map I.paramName + . amParams ) - (ses, stms) <- mkBottomArgs as ts - pure $ (p_map, M.insert 0 (AutoMapArg ses stms) a_map) + (p_map M.!? 1) + ses <- mkBottomArgs as ts + pure (p_map, M.insert 0 (AutoMapArg ses) a_map) | l == trueLevel am = do - (ps, p_stms) <- mkParams arg_vnames ts l + ps <- mkParams arg_vnames ts l d <- outerDim am l pure - ( M.insert l (AutoMapParam ps p_stms d) p_map, - M.insert l (AutoMapArg arg_vnames mempty) a_map + ( M.insert l (AutoMapParam ps d) p_map, + M.insert l (AutoMapArg arg_vnames) a_map ) | l < trueLevel am && l > 0 = do - (ps, p_stms) <- mkParams arg_vnames ts l + ps <- mkParams arg_vnames ts l d <- outerDim am l let as = map I.paramName $ amParams $ p_map M.! (l + 1) pure - ( M.insert l (AutoMapParam ps p_stms d) p_map, - M.insert l (AutoMapArg as mempty) a_map + ( M.insert l (AutoMapParam ps d) p_map, + M.insert l (AutoMapArg as) a_map ) | otherwise = error "" mkParams _ ts level = - collectStms $ - forM ts $ \t -> - newParam ("p_" <> arg_desc) $ argType (level - 1) am t - mkBottomArgs arg_vnames ts = - collectStms $ do - rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am - if I.shapeRank rep_shape > 0 - then concat <$> mapM (letValExp "autorep" . BasicOp . Replicate rep_shape . I.Var) arg_vnames - else pure arg_vnames - - argType level am t = I.stripArray (trueLevel am - level) t + forM ts $ \t -> + newParam ("p_" <> arg_desc) $ argType (level - 1) am t + mkBottomArgs arg_vnames ts = do + rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am + if I.shapeRank rep_shape > 0 + then + concat + <$> mapM + ( letValExp "autorep" + . BasicOp + . Replicate rep_shape + . I.Var + ) + arg_vnames + else pure arg_vnames internaliseShape :: E.Shape Size -> InternaliseM I.Shape internaliseShape = @@ -1151,118 +1112,10 @@ withAutoMapNew args_am func = do | otherwise = E.shapeRank $ autoFrame am outerDim :: AutoMap -> Int -> InternaliseM SubExp - outerDim am level = do - traceM $ - unlines - [ "##outerDim", - "##am", - show am, - "##level", - show level, - "## dff", - show (trueLevel am - level) - ] + outerDim am level = internaliseExp1 "" $ (!! (trueLevel am - level)) $ E.shapeDims $ autoFrame am -withAutoMap :: [AutoMap] -> String -> StructType -> [(E.Exp, Maybe VName)] -> ([([SubExp], Stms SOACS)] -> InternaliseM [SubExp]) -> InternaliseM [SubExp] -withAutoMap ams arg_desc res_t args_e innerM = do - (args, stms) <- - foldM - ( \(args, stms) arg -> do - (arg', stms') <- inScopeOf (reverse stms) $ collectStms $ internaliseArg arg_desc arg - pure (arg' : args, stms' : stms) - ) - (mempty, mempty) - (reverse args_e) - argts <- inScopeOf (reverse stms) $ (mapM . mapM) subExpType args - expand args stms argts ams (maximum ds) - where - stripAutoMapDims i am = - am {autoMap = E.Shape $ drop i $ E.shapeDims $ autoMap am} - max_am = maximumBy (\x y -> E.shapeRank x `compare` E.shapeRank y) $ fmap autoMap ams - inner_t = E.stripArray (E.shapeRank max_am) res_t - ds = map autoMapRank ams - mkLambdaParams level (ses, ts, stm, d) - | d == level = - Left - <$> zipWithM - ( \se t -> do - let t' = I.stripArray 1 t - p <- newParam "x" t' - addStms stm - pure ((se, p), t') - ) - ses - ts - | otherwise = pure $ Right $ zip ses ts - - internaliseShape = - fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims - - addReplicates = - zipWithM - ( \am arg -> do - rep_shape <- - internaliseShape $ - autoRep am `E.shapePrefix` autoFrame am - if I.shapeRank rep_shape > 0 - then concat <$> mapM (letValExp' "autoRep" . BasicOp . Replicate rep_shape) arg - else pure arg - ) - - expand args stms argts ams' level - | level <= 0 = do - args' <- addReplicates ams' args - innerM $ zip args' stms - | otherwise = do - let ds' = map autoMapRank ams' - arg_params <- mapM (mkLambdaParams level) $ zip4 args argts stms ds' - let argts' = map (either (map snd) (map snd)) arg_params - (ams'', stms') = - unzip $ - zipWith - ( \am stm -> - if autoMapRank am == level - then (stripAutoMapDims 1 am, mempty) - else (am, stm) - ) - ams' - stms - args' = map (either (map (I.Var . I.paramName . snd . fst)) (map fst)) arg_params - (map_ses, params) = unzip $ (concatMap . map) fst $ lefts arg_params - - ((ses, ses_ts), lam_stms) <- collectStms $ localScope (scopeOfLParams params) $ do - ses <- expand args' stms' argts' ams'' (level - 1) - ses_ts <- internaliseLambdaReturnType (E.toRes Nonunique inner_t) =<< mapM subExpType ses - pure (ses, ses_ts) - - case map_ses of - [] -> pure mempty - (map_se : _) -> do - outer_shape <- I.takeDims 1 . I.arrayShape <$> subExpType map_se - let I.Shape [outer_shape_se] = outer_shape - map_args <- forM map_ses $ \se -> do - se_t <- subExpType se - se_name <- letExp "map_arg" =<< toExp se - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter outer_shape 1 $ I.arrayShape se_t) - se_name - - letValExp' "automap" - . Op - . Screma outer_shape_se map_args - . mapSOAC - =<< mkLambda - params - ( ensureResultShape - (ErrorMsg [ErrorString "AutoMap: unexpected lambda result size"]) - mempty - ses_ts - =<< (addStms lam_stms >> pure (subExpsRes ses)) - ) + argType level am = I.stripArray (trueLevel am - level) generateCond :: E.Pat StructType -> @@ -1851,14 +1704,14 @@ data Function | FunctionHole SrcLoc deriving (Show) -findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName, AutoMap)]) +findFuncall :: E.AppExp -> (Function, [((E.Exp, Maybe VName), AutoMap)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info (argext, am), e) = (e, argext, am) + onArg (Info (argext, am), e) = ((e, argext), am) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e From 8af019ebd6aea8715bb6095e1f7b27d54a2f3879 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:05:42 -0800 Subject: [PATCH 133/296] Fixes. --- tests/automap/equality1.fut | 10 +++++----- tests/automap/pagerank.fut | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/automap/equality1.fut b/tests/automap/equality1.fut index 1604c49d93..b2a173f30d 100644 --- a/tests/automap/equality1.fut +++ b/tests/automap/equality1.fut @@ -1,23 +1,23 @@ -- == -- entry: bigger_to_smaller -- input { [[1,2],[3,4]] [1,2] } --- output { [true, false] } +-- output { [[true, true], [false, false]] } -- == -- entry: smaller_to_bigger -- input { [[1,2],[3,4]] [1,2] } --- output { [true, false] } +-- output { [[true, true], [false, false]] } -- == -- entry: smaller_to_bigger2 -- input { [[1,2],[3,4]] 1 } -- output { [[true,false],[false,false]]} -def bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = +entry bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = xss == ys -def smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n]bool = +entry smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = ys == xss -def smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = +entry smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = z == xss diff --git a/tests/automap/pagerank.fut b/tests/automap/pagerank.fut index c444932de5..3552990144 100644 --- a/tests/automap/pagerank.fut +++ b/tests/automap/pagerank.fut @@ -14,5 +14,5 @@ def calculate_dangling_ranks_am [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = let total = f32.sum weights / f32.i64 n in ranks + total -entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): []bool = - calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes +entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): bool = + and (calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes) From 321b5245c1bbbcc824153869adf3e2a8a260503f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:33:36 -0800 Subject: [PATCH 134/296] Looks like we actually do need some reshaping. --- src/Futhark/Internalise/Exps.hs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index d38e36493c..bdef5c071e 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1034,15 +1034,21 @@ withAutoMap args_am func = do (map_dim : _) -> do let params = map amParams $ param_map M.! l args = map amArgs $ arg_map M.! l + + reshaped_args <- + forM (concat args) $ \argvn -> + letExp "reshaped" $ + shapeCoerce [map_dim] argvn + letValExp' "automap" . Op - . Screma map_dim (concat args) + . Screma map_dim reshaped_args . mapSOAC =<< mkLambda - (concat params) - ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) - ) + (concat params) + ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) + ) buildArgMap :: ((E.Exp, Maybe VName), AutoMap, String) -> From bb037270053e99ce46c8012a72dad862cd372e72 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:41:23 -0800 Subject: [PATCH 135/296] Oops. Fix reshaping. --- src/Futhark/Internalise/Exps.hs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index bdef5c071e..dd6f16a8a1 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1036,9 +1036,14 @@ withAutoMap args_am func = do args = map amArgs $ arg_map M.! l reshaped_args <- - forM (concat args) $ \argvn -> + forM (concat args) $ \argvn -> do + arg_t <- subExpType $ I.Var argvn letExp "reshaped" $ - shapeCoerce [map_dim] argvn + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) + argvn letValExp' "automap" From 10cfde83bb80d4e8f9187b79f345374091437dc8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 22:46:11 -0800 Subject: [PATCH 136/296] Better to do the reshaping here, I think. --- src/Futhark/Internalise/Exps.hs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index dd6f16a8a1..e1180e274e 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1035,20 +1035,10 @@ withAutoMap args_am func = do let params = map amParams $ param_map M.! l args = map amArgs $ arg_map M.! l - reshaped_args <- - forM (concat args) $ \argvn -> do - arg_t <- subExpType $ I.Var argvn - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) - argvn - letValExp' "automap" . Op - . Screma map_dim reshaped_args + . Screma map_dim (concat args) . mapSOAC =<< mkLambda (concat params) @@ -1079,9 +1069,20 @@ withAutoMap args_am func = do | l == trueLevel am = do ps <- mkParams arg_vnames ts l d <- outerDim am l + + reshaped_args <- + forM arg_vnames $ \argvn -> do + arg_t <- subExpType $ I.Var argvn + letExp "reshaped" $ + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter (I.Shape [d]) 1 $ I.arrayShape arg_t) + argvn + pure ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg arg_vnames) a_map + M.insert l (AutoMapArg reshaped_args) a_map ) | l < trueLevel am && l > 0 = do ps <- mkParams arg_vnames ts l From 5adef9022786be1ffe856ff9221fab9de6846a1d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 26 Feb 2024 20:46:51 +0100 Subject: [PATCH 137/296] Begin handling overloaded type variables. --- src/Language/Futhark/TypeChecker/Constraints.hs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ad11df4729..0725d50418 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -159,7 +159,15 @@ subTyVar v lvl t = do linkTyVar :: VName -> VName -> SolveM () linkTyVar v t = do occursCheck v $ Scalar $ TypeVar NoUniqueness (qualName t) [] + tyvars <- gets solverTyVars modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} + tyvars' <- + case (M.lookup v tyvars, M.lookup t tyvars) of + (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree)) -> + pure $ M.insert t (TyVarUnsol lvl info) tyvars + -- TODO: handle more cases. + _ -> pure tyvars + modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) tyvars'} -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -190,7 +198,7 @@ solveCt :: Ct -> SolveM () solveCt ct = case ct of CtEq t1 t2 -> solveCt' (t1, t2) - CtAM _ _ _ -> pure () -- Good vibes only. + CtAM {} -> pure () -- Good vibes only. where bad = throwError $ "Unsolvable: " <> prettyText ct solveCt' (t1, t2) = do From 1d9d25d897133feb0b99bd1562d259736e682b0d Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 23:27:51 -0800 Subject: [PATCH 138/296] These should be frames. --- src/Futhark/Internalise/Defunctionalise.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 98b698aeeb..375fe18577 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -972,7 +972,7 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a pure ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, - autoMapSV (autoMap am) sv + autoMapSV (autoFrame am) sv -- sv ) -- If 'f' is a dynamic function, we just leave the application in @@ -996,7 +996,7 @@ defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do "## ret sv", show $ autoMapSV (autoMap am) sv ] - pure (apply_e, autoMapSV (autoMap am) sv) + pure (apply_e, autoMapSV (autoFrame am) sv) -- pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = From 266368a9058e233928b2805ba485c9c4ee01f21a Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 26 Feb 2024 23:34:46 -0800 Subject: [PATCH 139/296] Revert "Better to do the reshaping here, I think." I was wrong. --- src/Futhark/Internalise/Exps.hs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index e1180e274e..dd6f16a8a1 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1035,10 +1035,20 @@ withAutoMap args_am func = do let params = map amParams $ param_map M.! l args = map amArgs $ arg_map M.! l + reshaped_args <- + forM (concat args) $ \argvn -> do + arg_t <- subExpType $ I.Var argvn + letExp "reshaped" $ + I.BasicOp $ + I.Reshape + I.ReshapeCoerce + (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) + argvn + letValExp' "automap" . Op - . Screma map_dim (concat args) + . Screma map_dim reshaped_args . mapSOAC =<< mkLambda (concat params) @@ -1069,20 +1079,9 @@ withAutoMap args_am func = do | l == trueLevel am = do ps <- mkParams arg_vnames ts l d <- outerDim am l - - reshaped_args <- - forM arg_vnames $ \argvn -> do - arg_t <- subExpType $ I.Var argvn - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter (I.Shape [d]) 1 $ I.arrayShape arg_t) - argvn - pure ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg reshaped_args) a_map + M.insert l (AutoMapArg arg_vnames) a_map ) | l < trueLevel am && l > 0 = do ps <- mkParams arg_vnames ts l From d190923710fd74e5831f8f26cd4a5d276f27c02f Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 27 Feb 2024 13:04:36 -0800 Subject: [PATCH 140/296] Strip off automapped shapes from arg static values. Fixes `optionpricing.fut` bug. --- src/Futhark/Internalise/Defunctionalise.hs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 375fe18577..2f9ea6d787 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -907,14 +907,18 @@ defuncApplyArg :: DefM (Exp, StaticVal) defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, am), arg), _) = do (arg', arg_sv) <- defuncExp arg - let env' = alwaysMatchPatSV pat arg_sv + let arg_sv' = + case arg_sv of + (Dynamic ty@(Array {})) -> Dynamic $ stripArray (shapeRank $ autoFrame am) ty + _ -> arg_sv dims = mempty + env' = alwaysMatchPatSV pat arg_sv' (lam_e', sv) <- localNewEnv (env' <> closure_env) $ defuncExp lam_e let closure_pat = buildEnvPat dims closure_env - pat' = updatePat pat arg_sv + pat' = updatePat pat arg_sv' globals <- asks fst From 56c18c05389095a987645dbeaa1feac2ab51b1ea Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 27 Feb 2024 13:07:15 -0800 Subject: [PATCH 141/296] Prevent loops. --- src/Futhark/Internalise/Defunctionalise.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 2f9ea6d787..2819f444c3 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -1223,7 +1223,7 @@ matchPatSV (PatConstr c1 _ ps _) (Dynamic (Scalar (Sum fs))) else Nothing | otherwise = error $ "matchPatSV: missing constructor in type: " ++ prettyString c1 -matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t +matchPatSV pat (Dynamic t@(Scalar Record {})) = matchPatSV pat $ svFromType t matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType $ toParam Observe t matchPatSV pat sv = error $ From 8e1acffa1b28c5c88e7e54d507193a407680ca21 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 28 Feb 2024 10:49:19 -0800 Subject: [PATCH 142/296] Don't need this. --- src/Futhark/Internalise/Exps.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index dd6f16a8a1..21f559c00c 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1062,10 +1062,10 @@ withAutoMap args_am func = do ses <- internaliseArg arg_desc arg arg_vnames <- mapM (letExp "" <=< eSubExp) ses ts <- mapM subExpType ses - foldM (mkArgsAndParams arg_vnames ses ts) (mempty, mempty) $ + foldM (mkArgsAndParams arg_vnames ts) (mempty, mempty) $ reverse [0 .. trueLevel am] where - mkArgsAndParams arg_vnames ses ts (p_map, a_map) l + mkArgsAndParams arg_vnames ts (p_map, a_map) l | l == 0 = do let as = maybe From 9e5d59a6ac82848cf7883b1da83fc2e899fdc741 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 28 Feb 2024 11:54:41 -0800 Subject: [PATCH 143/296] Undo AUTOMAP-handling in the later phases of internalization. --- src/Futhark/Internalise/Defunctionalise.hs | 95 +------ src/Futhark/Internalise/Exps.hs | 283 ++------------------- src/Futhark/Internalise/Monomorphise.hs | 28 +- 3 files changed, 43 insertions(+), 363 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 2819f444c3..82cc845d69 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -16,7 +16,7 @@ import Data.Maybe import Data.Set qualified as S import Futhark.IR.Pretty () import Futhark.MonadFreshNames -import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) +import Futhark.Util (mapAccumLM, nubOrd) import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types (Subst (..), applySubst) @@ -905,20 +905,16 @@ defuncApplyArg :: (Exp, StaticVal) -> (((Maybe VName, AutoMap), Exp), [ParamType]) -> DefM (Exp, StaticVal) -defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, am), arg), _) = do +defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, _), arg), _) = do (arg', arg_sv) <- defuncExp arg - let arg_sv' = - case arg_sv of - (Dynamic ty@(Array {})) -> Dynamic $ stripArray (shapeRank $ autoFrame am) ty - _ -> arg_sv + let env' = alwaysMatchPatSV pat arg_sv dims = mempty - env' = alwaysMatchPatSV pat arg_sv' (lam_e', sv) <- localNewEnv (env' <> closure_env) $ defuncExp lam_e let closure_pat = buildEnvPat dims closure_env - pat' = updatePat pat arg_sv' + pat' = updatePat pat arg_sv globals <- asks fst @@ -959,49 +955,20 @@ defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) (((argext, a fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) callret <- unRetType lifted_rettype - debugTraceM $ - unlines - [ "##defuncApplyArg LambdaSV", - "## fname", - fname_s, - "## f'", - prettyString f', - "## arg", - prettyString arg, - "## sv", - show sv, - "## ret sv", - show $ autoMapSV (autoMap am) sv - ] - pure - ( mkApply fname' [(Nothing, mempty, f'), (argext, am, arg')] callret, - autoMapSV (autoFrame am) sv - -- sv + ( mkApply fname' [(Nothing, mempty, f'), (argext, mempty, arg')] callret, + sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. -defuncApplyArg _ (f', DynamicFun _ sv) (((argext, am), arg), argtypes) = do +defuncApplyArg _ (f', DynamicFun _ sv) (((argext, _), arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] - apply_e = mkApply f' [(argext, am, arg')] callret - debugTraceM $ - unlines - [ "##defuncApplyArg DynamicFun", - "## f'", - prettyString f', - "## arg", - prettyString arg, - "## sv", - show sv, - "## ret sv", - show $ autoMapSV (autoMap am) sv - ] - pure (apply_e, autoMapSV (autoFrame am) sv) --- pure (apply_e, sv) + apply_e = mkApply f' [(argext, mempty, arg')] callret + pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = error $ @@ -1017,11 +984,6 @@ updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e -autoMapSV :: Shape Size -> StaticVal -> StaticVal -autoMapSV shape (Dynamic t) = - Dynamic $ arrayOfWithAliases (diet t) shape t -autoMapSV _ sv = sv - defuncApply :: Exp -> NE.NonEmpty ((Maybe VName, AutoMap), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) @@ -1037,39 +999,10 @@ defuncApply f args appres loc = do _ -> do let fname = liftedName 0 f (argtypes, _) = unfoldFunType $ typeOf f - (app, app_sv) <- - fmap (first $ updateReturn appres) $ - foldM (defuncApplyArg fname) (f', f_sv) $ - NE.zip args $ - NE.tails argtypes - - let (p_ts, _) = unfoldFunType $ typeOf f - arg_ts = typeOf . snd <$> args - -- am_dims = zipWith typeShapePrefix (NE.toList arg_ts) p_ts - -- ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) am_dims - ams = NE.toList $ autoMap . snd . fst <$> args - ret_am = maximumBy (\x y -> shapeRank x `compare` shapeRank y) ams - debugTraceM $ - unlines - [ "## defuncApply", - "## f", - prettyString f, - "## args", - prettyString $ snd <$> args, - "## appres", - show appres, - "## app", - prettyString app, - "## app_sv", - show app_sv, - "## f type", - prettyString $ typeOf f, - "## arg types", - prettyString $ typeOf . snd <$> args, - "## ret_am", - prettyString ret_am - ] - pure (app, app_sv) + fmap (first $ updateReturn appres) $ + foldM (defuncApplyArg fname) (f', f_sv) $ + NE.zip args $ + NE.tails argtypes where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. @@ -1223,7 +1156,7 @@ matchPatSV (PatConstr c1 _ ps _) (Dynamic (Scalar (Sum fs))) else Nothing | otherwise = error $ "matchPatSV: missing constructor in type: " ++ prettyString c1 -matchPatSV pat (Dynamic t@(Scalar Record {})) = matchPatSV pat $ svFromType t +matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType $ toParam Observe t matchPatSV pat sv = error $ diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 21f559c00c..b5684552e4 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -8,13 +8,11 @@ module Futhark.Internalise.Exps (transformProg) where import Control.Monad import Control.Monad.Reader import Data.Bifunctor -import Data.Either import Data.Foldable (toList) -import Data.List (elemIndex, find, intercalate, intersperse, maximumBy, transpose, zip4) +import Data.List (elemIndex, find, intercalate, intersperse, transpose) import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M -import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T import Futhark.IR.SOACS as I hiding (stmPat) @@ -348,13 +346,12 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = let subst = map (,E.ExpSubst (E.sizeFromInteger 0 mempty)) ext et' = E.applySubst (`lookup` subst) et internaliseExp desc (E.Hole (Info et') loc) - (FunctionName qfname, argsam) -> do + (FunctionName qfname, args) -> do -- Argument evaluation is outermost-in so that any existential sizes -- created by function applications can be brought into scope. let fname = nameFromString $ prettyString $ baseName $ qualLeaf qfname loc = srclocOf e - (args, ams) = unzip argsam - args_am_desc = zip3 args ams (repeat (nameToString fname ++ "_arg")) + arg_desc = nameToString fname ++ "_arg" -- Some functions are magical (overloaded) and we handle that here. case () of @@ -364,21 +361,20 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do - withAutoMap args_am_desc $ \args' -> do - let prepareArg ((arg, _), am, _) arg' = - (E.toStruct $ E.stripArray (E.shapeRank $ autoMap am) (E.typeOf arg), arg') - internalise $ zipWith prepareArg args_am_desc args' + let prepareArg (arg, _) = + (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg + internalise =<< mapM prepareArg args | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, - Just (rettype, _) <- M.lookup fname I.builtInFunctions -> - withAutoMap args_am_desc $ \args' -> do - let tag ses = [(se, I.Observe) | se <- ses] - let args'' = concatMap tag args' - letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) + Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do + let tag ses = [(se, I.Observe) | se <- ses] + args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) + let args'' = concatMap tag args' + letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do - withAutoMap args_am_desc $ \args' -> do - funcall desc qfname (concat args') loc + args' <- concat . reverse <$> mapM (internaliseArg arg_desc) (reverse args) + funcall desc qfname args' loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = internalisePat desc sizes pat e $ internaliseExp desc body internaliseAppExp _ _ (E.LetFun ofname _ _ _) = @@ -879,255 +875,6 @@ internalisePatLit (E.PatLitFloat x) (E.Scalar (E.Prim (E.FloatType ft))) = internalisePatLit l t = error $ "Nonsensical pattern and type: " ++ show (l, t) --- | Internalization of 'AutoMap'-annotated applications. --- --- Each application @f x@ has an annotation with @AutoMap R M F@ where --- @R, M, F@ are the autorep, automap, and frame shapes, --- respectively. --- --- The application @f x@ will have type @F t@ for some @t@, i.e. @(f --- x) : F t@. The frame @F@ is a prefix of the type of @f x@; namely --- it is the total accumulated shape that is due to implicit maps. --- Another way of thinking about that is that @|F|@ is is the level --- of the automap-nest that @f x@ is in. For example, if @|F| = 2@ --- then we know that @f x@ implicitly stands for --- --- > map (\x' -> map (\x'' -> f x'') x') x --- --- For an application with a non-empty autorep annotation, the frame --- tells about how many dimensions of the replicate can be eliminated. --- For example, @[[1,2],[3,4]] + 5@ will yield the following annotations: --- --- > ([[1,2],[3,4]] +) -- AutoMap {R = mempty, M = [2][2], F = [2][2]} --- > (([[1,2],[3,4]] +) 5) -- AutoMap {R = [2][2], M = mempty, F = [2][2]} --- --- All replicated arguments are pushed down the auto-map nest. Each --- time a replicated argument is pushed down a level of an --- automap-nest, one fewer replicates is needed (i.e., the outermost --- dimension of @R@ can be dropped). Replicated arguments are pushed --- down the nest until either 1) the bottom of the nest is encountered --- or 2) no replicate dimensions remain. For example, in the second --- application above @R@ = @F@, so we can push the replicated argument --- down two levels. Since each level effectively removes a dimension --- of the replicate, no replicates will be required: --- --- > map (\xs -> map (\x -> f x'' 5) xs) [[1,2],[3,4]] --- --- The number of replicates that are actually required is given by --- max(|R| - |F|, 0). --- --- An expression's "true level" is the level at which that expression --- will appear in the automap-nest. The bottom of a mapnest is level 0. --- --- * For annotations with @R = mempty@, the true level is @|F|@. --- * For annotations with @M = mempty@, the true level is @|F| - |R|@. --- --- If @|R| > |F|@ then actual replicates (namely @|R| - |F|@ of them) --- will be required at the bottom of the mapnest. --- --- Note that replicates can only appear at the bottom of a mapnest; any --- expression of the form --- --- > map (\ls x' rs -> e) (replicate x) --- --- can always be written as --- --- > map (\ls rs -> e[x' -> x]) --- --- Let's look at another example. Consider (with exact sizes omitted for brevity) --- --- > f : a -> a -> a -> []a -> [][][]a -> a --- > xss : [][]a --- > ys : []a --- > zsss : [][][]a --- > w : a --- > vss : [][]a --- --- and the application --- --- > f xss ys zsss w vss --- --- which will have the following annotations --- --- > (f xss) -- AutoMap {R = mempty, M = [][], F = [][]} (1) --- > ((f xss) ys) -- AutoMap {R = [], M = mempty, F = [][]} (2) --- > (((f xss) ys) zsss) -- AutoMap {R = mempty, M = [], F = [][][]} (3) --- > ((((f xss) ys) zsss) w) -- AutoMap {R = [][][][], M = mempty, F = [][][]} (4) --- > (((((f xss) ys) zsss) w) vss) -- AutoMap {R = [], M = mempty, F = [][][]} (5) --- --- This will yield the following mapnest. --- --- > map (\zss -> --- > map (\xs zs vs -> --- > map (\x y z v -> f x y z (replicate w) v) xs ys zs v) xss zss vss) zsss --- --- Let's see how we'd construct this mapnest from the annotations. We construct --- the nest bottom-up. We have: --- --- Application | True level --- --------------------------- --- (1) | |[][]| = 2 --- (2) | |[][]| - |[]| = 1 --- (3) | |[][][]| = 3 --- (4) | |[][][]| - |[][][][]| = -1 --- (5) | |[][][]| - |[]| = 2 --- --- We start at level 0. --- * Any argument with a negative true level of @-n@ will be replicated @n@ times; --- the exact shapes can be found by removing the @F@ postfix from @R@, --- i.e. @R = shapes_to_rep_by <> F@. --- * Any argument with a 0 true level will be included. --- * For any argument @arg@ with a positive true level, we construct a new parameter --- whose type is @arg@ with the leading @n@ dimensions (where @n@ is the true level) --- removed. --- --- Following the rules above, @w@ will be replicated once. For the remaining arguments, --- we create new parameters @x : a, y : a, z : a , v : a@. Hence, level 0 becomes --- --- > f x y z (replicate w) v --- --- At level l > 0: --- * There are no replicates. --- * Any argument with l true level will be included verbatim. --- * Any argument with true level > l will have a new parameter constructed for it, --- whose type has the leading @n - l@ dimensions (where @n@ is the true level) removed. --- * We surround the previous level with a map that binds that levels' new parameters --- and is passed the current levels' arguments. --- --- Following the above recipe for level 1, we create parameters --- @xs : []a, zs : []a, vs :[]a@ and obtain --- --- > map (\x y z v -> f x y z (replicate w) v) xs ys zs vs --- --- This process continues until the level is greater than the maximum --- true level of any application, at which we terminate. -type Level = Int - -data AutoMapArg = AutoMapArg - { amArgs :: [VName] - } - deriving (Show) - -data AutoMapParam = AutoMapParam - { amParams :: [LParam SOACS], - amMapDim :: SubExp - } - deriving (Show) - -withAutoMap :: - [((E.Exp, Maybe VName), AutoMap, String)] -> - ([[SubExp]] -> InternaliseM [SubExp]) -> - InternaliseM [SubExp] -withAutoMap args_am func = do - (param_maps, arg_maps) <- - unzip . reverse - <$> mapM buildArgMap (reverse args_am) - let param_map = M.unionsWith (<>) $ (fmap . fmap) pure param_maps - arg_map = M.unionsWith (<>) $ (fmap . fmap) pure arg_maps - buildMapNest param_map arg_map $ maximum $ M.keys arg_map - where - buildMapNest _ arg_map 0 = - func $ map (map I.Var . amArgs) $ arg_map M.! 0 - buildMapNest param_map arg_map l = - case map amMapDim $ param_map M.! l of - [] -> buildMapNest param_map arg_map (l - 1) - (map_dim : _) -> do - let params = map amParams $ param_map M.! l - args = map amArgs $ arg_map M.! l - - reshaped_args <- - forM (concat args) $ \argvn -> do - arg_t <- subExpType $ I.Var argvn - letExp "reshaped" $ - I.BasicOp $ - I.Reshape - I.ReshapeCoerce - (reshapeOuter (I.Shape [map_dim]) 1 $ I.arrayShape arg_t) - argvn - - letValExp' - "automap" - . Op - . Screma map_dim reshaped_args - . mapSOAC - =<< mkLambda - (concat params) - ( subExpsRes <$> buildMapNest param_map arg_map (l - 1) - ) - - buildArgMap :: - ((E.Exp, Maybe VName), AutoMap, String) -> - InternaliseM (M.Map Level AutoMapParam, M.Map Level AutoMapArg) - buildArgMap (arg, am, arg_desc) = do - ses <- internaliseArg arg_desc arg - arg_vnames <- mapM (letExp "" <=< eSubExp) ses - ts <- mapM subExpType ses - foldM (mkArgsAndParams arg_vnames ts) (mempty, mempty) $ - reverse [0 .. trueLevel am] - where - mkArgsAndParams arg_vnames ts (p_map, a_map) l - | l == 0 = do - let as = - maybe - arg_vnames - ( map I.paramName - . amParams - ) - (p_map M.!? 1) - ses <- mkBottomArgs as ts - pure (p_map, M.insert 0 (AutoMapArg ses) a_map) - | l == trueLevel am = do - ps <- mkParams arg_vnames ts l - d <- outerDim am l - pure - ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg arg_vnames) a_map - ) - | l < trueLevel am && l > 0 = do - ps <- mkParams arg_vnames ts l - d <- outerDim am l - let as = - map I.paramName $ - amParams $ - p_map M.! (l + 1) - pure - ( M.insert l (AutoMapParam ps d) p_map, - M.insert l (AutoMapArg as) a_map - ) - | otherwise = error "" - - mkParams _ ts level = - forM ts $ \t -> - newParam ("p_" <> arg_desc) $ argType (level - 1) am t - mkBottomArgs arg_vnames ts = do - rep_shape <- internaliseShape $ autoRep am `E.shapePrefix` autoFrame am - if I.shapeRank rep_shape > 0 - then - concat - <$> mapM - ( letValExp "autorep" - . BasicOp - . Replicate rep_shape - . I.Var - ) - arg_vnames - else pure arg_vnames - - internaliseShape :: E.Shape Size -> InternaliseM I.Shape - internaliseShape = - fmap I.Shape . mapM (internaliseExp1 "") . E.shapeDims - - trueLevel :: AutoMap -> Int - trueLevel am - | autoMap am == mempty = max 0 $ E.shapeRank (autoFrame am) - E.shapeRank (autoRep am) - | otherwise = E.shapeRank $ autoFrame am - - outerDim :: AutoMap -> Int -> InternaliseM SubExp - outerDim am level = - internaliseExp1 "" $ (!! (trueLevel am - level)) $ E.shapeDims $ autoFrame am - - argType level am = I.stripArray (trueLevel am - level) - generateCond :: E.Pat StructType -> [I.SubExp] -> @@ -1715,14 +1462,14 @@ data Function | FunctionHole SrcLoc deriving (Show) -findFuncall :: E.AppExp -> (Function, [((E.Exp, Maybe VName), AutoMap)]) +findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where - onArg (Info (argext, am), e) = ((e, argext), am) + onArg (Info (argext, _), e) = (e, argext) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index aada3924c0..2617f95b8b 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -535,7 +535,7 @@ transformAppExp (Loop sparams pat e1 form body loc) res = do (pat_sizes, pat'') <- sizesForPat pat' res' <- transformAppRes res pure $ AppExp (Loop (sparams' ++ pat_sizes) pat'' e1' form' body' loc) (Info res') -transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, am1)) (e2, Info (d2, am2)) loc) res = do +transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, _)) (e2, Info (d2, _)) loc) res = do (AppRes ret ext) <- transformAppRes res fname' <- transformFName loc fname (toStruct t) e1' <- transformExp e1 @@ -570,8 +570,8 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, Info (d1, am1)) (e2, Info (d2, a where applyOp ret ext fname' x y = mkApply - (mkApply fname' [(d1, am1, x)] (AppRes ret mempty)) - [(d2, am2, y)] + (mkApply fname' [(d1, mempty, x)] (AppRes ret mempty)) + [(d2, mempty, y)] (AppRes ret ext) makeVarParam arg = do @@ -664,27 +664,27 @@ transformExp (Lambda params e0 decl tp loc) = do transformExp (OpSection qn t loc) = transformExp $ Var qn t loc transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc) = do - let (Info (xp, xtype, xargext, xam), Info (yp, ytype)) = arg + let (Info (xp, xtype, xargext, _), Info (yp, ytype)) = arg e' <- transformExp e desugarBinOpSection fname (Just e') Nothing t - (xp, xtype, xargext, xam) - (yp, ytype, Nothing, mempty) + (xp, xtype, xargext) + (yp, ytype, Nothing) (rettype, retext) loc transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do - let (Info (xp, xtype), Info (yp, ytype, yargext, yam)) = arg + let (Info (xp, xtype), Info (yp, ytype, yargext, _)) = arg e' <- transformExp e desugarBinOpSection fname Nothing (Just e') t - (xp, xtype, Nothing, mempty) - (yp, ytype, yargext, yam) + (xp, xtype, Nothing) + (yp, ytype, yargext) (rettype, []) loc transformExp (ProjectSection fields (Info t) loc) = do @@ -735,12 +735,12 @@ desugarBinOpSection :: Maybe Exp -> Maybe Exp -> StructType -> - (PName, ParamType, Maybe VName, AutoMap) -> - (PName, ParamType, Maybe VName, AutoMap) -> + (PName, ParamType, Maybe VName) -> + (PName, ParamType, Maybe VName) -> (ResRetType, [VName]) -> SrcLoc -> MonoM Exp -desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, yext, yam) (RetType dims rettype, retext) loc = do +desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (RetType dims rettype, retext) loc = do t' <- transformType t op <- transformFName loc fname $ toStruct t (v1, wrap_left, e1, p1) <- makeVarParam e_left =<< transformType xtype @@ -748,7 +748,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, ye let apply_left = mkApply op - [(xext, xam, e1)] + [(xext, mempty, e1)] (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) []) onDim (Var d typ _) | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc @@ -757,7 +757,7 @@ desugarBinOpSection fname e_left e_right t (xp, xtype, xext, xam) (yp, ytype, ye rettype' = first onDim rettype body <- scoping (S.fromList [v1, v2]) $ - mkApply apply_left [(yext, yam, e2)] + mkApply apply_left [(yext, mempty, e2)] <$> transformAppRes (AppRes (toStruct rettype') retext) rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype' pure . wrap_left . wrap_right $ From 8d88c9582154ac4a22c83e268e8aec2f8acd6c45 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 28 Feb 2024 12:08:44 -0800 Subject: [PATCH 144/296] Undo AUTOMAP-handling in `FullNormalise.hs`. --- src/Futhark/Internalise/FullNormalise.hs | 46 ++++++++---------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 17841c9f53..a3dca1f8bd 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -203,13 +203,13 @@ getOrdering final (Lambda params body mte ret loc) = do nameExp final $ Lambda params body' mte ret loc getOrdering _ (OpSection qn ty loc) = pure $ Var qn ty loc -getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, xam), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do +getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, _), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do x <- getOrdering False e yn <- newNameFromString "y" let y = Var (qualName yn) (Info $ toStruct yty) mempty ret' = applySubst (pSubst x y) ret body = - mkApply (Var op ty mempty) [(xext, xam, x), (Nothing, mempty, y)] $ + mkApply (Var op ty mempty) [(xext, mempty, x), (Nothing, mempty, y)] $ AppRes (toStruct ret') exts nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc where @@ -217,12 +217,12 @@ getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext, xam), Info (yp, yty | Named p <- xp, p == vn = Just $ ExpSubst x | Named p <- yp, p == vn = Just $ ExpSubst y | otherwise = Nothing -getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext, yam)) (Info (RetType dims ret)) loc) = do +getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext, _)) (Info (RetType dims ret)) loc) = do xn <- newNameFromString "x" y <- getOrdering False e let x = Var (qualName xn) (Info $ toStruct xty) mempty ret' = applySubst (pSubst x y) ret - body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, yam, y)] $ AppRes (toStruct ret') [] + body = mkApply (Var op ty mempty) [(Nothing, mempty, x), (yext, mempty, y)] $ AppRes (toStruct ret') [] nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc where pSubst x y vn @@ -298,41 +298,25 @@ getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do While e -> While <$> transformBody e body' <- transformBody body nameExp final $ AppExp (Loop sizes pat einit' form' body' loc) resT -getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, elam)) (er, Info (erp, eram)) loc) (Info resT)) = do +getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info (elp, _)) (er, Info (erp, _)) loc) (Info resT)) = do -- Rewrite short-circuiting boolean operators on scalars to explicit -- if-then-else. Automapped cases are turned into applications of -- intrinsic functions. expr' <- case (isOr, isAnd) of - (True, _) - | elam == mempty, - eram == mempty -> do - el' <- naming "or_lhs" $ getOrdering True el - er' <- naming "or_rhs" $ transformBody er - pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) - | otherwise -> do - el' <- naming "or_lhs" $ getOrdering False el - er' <- naming "or_rhs" $ getOrdering False er - pure $ mkApply orop [(elp, elam, el'), (erp, eram, er')] resT - (_, True) - | elam == mempty, - eram == mempty -> do - el' <- naming "and_lhs" $ getOrdering True el - er' <- naming "and_rhs" $ transformBody er - pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) - | otherwise -> do - el' <- naming "and_lhs" $ getOrdering False el - er' <- naming "and_rhs" $ getOrdering False er - pure $ mkApply andop [(elp, elam, el'), (erp, eram, er')] resT - _ -> do + (True, _) -> do + el' <- naming "or_lhs" $ getOrdering True el + er' <- naming "or_rhs" $ transformBody er + pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT) + (_, True) -> do + el' <- naming "and_lhs" $ getOrdering True el + er' <- naming "and_rhs" $ transformBody er + pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT) + (False, False) -> do el' <- naming (prettyString op <> "_lhs") $ getOrdering False el er' <- naming (prettyString op <> "_rhs") $ getOrdering False er - pure $ mkApply (Var op opT oloc) [(elp, elam, el'), (erp, eram, er')] resT + pure $ mkApply (Var op opT oloc) [(elp, mempty, el'), (erp, mempty, er')] resT nameExp final expr' where - bool = Scalar $ Prim Bool - opt = foldFunType [bool, bool] $ RetType [] bool - andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty - orop = Var (qualName (intrinsicVar "||")) (Info opt) mempty isOr = baseName (qualLeaf op) == "||" isAnd = baseName (qualLeaf op) == "&&" getOrdering final (AppExp (LetWith (Ident dest dty dloc) (Ident src sty sloc) slice e body loc) _) = do From de5cf38fa8e36e7145db8d080346158ec3292b9e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 14:54:36 +0100 Subject: [PATCH 145/296] Use proper type here. --- src/Language/Futhark/TypeChecker/Terms/Loop.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index 7cba8af7e8..51d1c8ceba 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -221,14 +221,16 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do case form of For i uboundexp -> do uboundexp' <- checkExp uboundexp - bindingIdent i . bindingPat [] mergepat merge_t $ + it <- expType uboundexp' + let i' = i {identType = Info it} + bindingIdent i' . bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do loopbody' <- checkExp loopbody (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' pure ( sparams, mergepat'', - For i uboundexp', + For i' uboundexp', loopbody' ) ForIn xpat e -> do From dcfdf1e5b373e0a8625862499c864e74f57c1161 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 14:56:09 +0100 Subject: [PATCH 146/296] Also update type here. --- src/Language/Futhark/TypeChecker/Terms.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 25c6dcb0c8..68547cb024 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -363,8 +363,8 @@ checkExp (RecordLit fs loc) = where checkField (RecordFieldExplicit f e rloc) = RecordFieldExplicit f <$> checkExp e <*> pure rloc - checkField (RecordFieldImplicit name info rloc) = - pure $ RecordFieldImplicit name info rloc + checkField (RecordFieldImplicit name (Info t) rloc) = + RecordFieldImplicit name <$> (Info <$> replaceTyVars rloc t) <*> pure rloc checkExp (ArrayLit all_es _ loc) = -- Construct the result type and unify all elements with it. We -- only create a type variable for empty arrays; otherwise we use From 9c869ab0b9d48477deda0271e0a4589532dca4ff Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 16:38:13 +0100 Subject: [PATCH 147/296] Fix handling of overloaded type variables. --- .../Futhark/TypeChecker/Constraints.hs | 25 +++++++------------ src/Language/Futhark/TypeChecker/Terms2.hs | 25 +++++++++++++------ src/Language/Futhark/TypeChecker/Types.hs | 18 +++++++++++++ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 0725d50418..10206b0bb8 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -22,6 +22,7 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Types (substTyVars) type SVar = VName @@ -99,21 +100,13 @@ newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} initialState :: TyVars -> SolverState initialState tyvars = SolverState $ M.map (uncurry TyVarUnsol) tyvars -substTyVars :: (Monoid u) => M.Map TyVar TyVarSol -> TypeBase SComp u -> TypeBase SComp u -substTyVars m t@(Scalar (TypeVar u (QualName qs v) args)) = +substTyVar :: (Monoid u) => M.Map TyVar TyVarSol -> VName -> Maybe (TypeBase SComp u) +substTyVar m v = case M.lookup v m of - Just (TyVarLink v') -> - substTyVars m $ Scalar $ TypeVar u (QualName qs v') args - Just (TyVarSol _ t') -> second (const mempty) $ substTyVars m t' - Just (TyVarUnsol {}) -> t - Nothing -> t -substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt -substTyVars m (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars m) fs -substTyVars m (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars m) cs -substTyVars m (Scalar (Arrow u pname d t1 (RetType ext t2))) = - Scalar $ Arrow u pname d (substTyVars m t1) $ RetType ext $ substTyVars m t2 `setUniqueness` uniqueness t2 -substTyVars m (Array u shape elemt) = - arrayOfWithAliases u shape $ substTyVars m $ Scalar elemt + Just (TyVarLink v') -> substTyVar m v' + Just (TyVarSol _ t') -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (TyVarUnsol {}) -> Nothing + Nothing -> Nothing -- | A solution maps a type variable to its substitution. This -- substitution is complete, in the sense there are no right-hand @@ -127,7 +120,7 @@ solution s = ) where mkSubst (TyVarSol _lvl t) = - Just $ Right $ first (const ()) $ substTyVars (solverTyVars s) t + Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t mkSubst (TyVarLink v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) @@ -143,7 +136,7 @@ newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} occursCheck :: VName -> Type -> SolveM () occursCheck v tp = do vars <- gets solverTyVars - let tp' = substTyVars vars tp + let tp' = substTyVars (substTyVar vars) tp when (v `S.member` typeVars tp') . throwError . docText $ "Occurs check: cannot instantiate" <+> prettyName v diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index ccba90ab07..f65bbb1811 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -640,7 +640,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs debugTraceM $ - unlines $ + unlines [ "## checkApplyOne", "## fname", prettyString fname, @@ -1123,12 +1123,12 @@ checkExp (Coerce e te NoInfo loc) = do ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc -doDefaults :: +doDefault :: S.Set VName -> VName -> Either [PrimType] (TypeBase () NoUniqueness) -> TermM (TypeBase () NoUniqueness) -doDefaults tyvars_at_toplevel v (Left pts) +doDefault tyvars_at_toplevel v (Left pts) | Signed Int32 `elem` pts = do when (v `S.member` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to i32." @@ -1145,7 +1145,18 @@ doDefaults tyvars_at_toplevel v (Left pts) "Add a type annotation to disambiguate the type." where usage = mkUsage NoLoc "overload" -doDefaults _ _ (Right t) = pure t +doDefault _ _ (Right t) = pure t + +-- | Apply defaults on otherwise ambiguous types. This may result in +-- some type variables becoming known, so we have to perform +-- substitutions on the RHS of the substitutions afterwards. +doDefaults :: + S.Set VName -> + M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -> + TermM (M.Map TyVar (TypeBase () NoUniqueness)) +doDefaults tyvars_at_toplevel substs = do + substs' <- M.traverseWithKey (doDefault tyvars_at_toplevel) substs + pure $ M.map (substTyVars (`M.lookup` substs')) substs' checkValDef :: ( VName, @@ -1187,8 +1198,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do solutions <- forM cts_tyvars' $ - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) - . uncurry solve + bitraverse pure (traverse (doDefaults mempty)) . uncurry solve forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> debugTraceM $ @@ -1213,6 +1223,5 @@ checkSingleExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars solution <- - bitraverse pure (traverse (M.traverseWithKey (doDefaults mempty))) $ - solve cts tyvars + bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars pure (solution, e') diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 22fa6fc5a1..b89802c5e0 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -8,6 +8,7 @@ module Language.Futhark.TypeChecker.Types TypeSubs, Substitutable (..), substTypesAny, + substTyVars, -- * Witnesses mustBeExplicitInType, @@ -531,6 +532,23 @@ substTypesAny lookupSubst ot = toAny d = d in first toAny ot' +-- | Substitution without caring about sizes. +substTyVars :: (Monoid u) => (VName -> Maybe (TypeBase d NoUniqueness)) -> TypeBase d u -> TypeBase d u +substTyVars f t@(Scalar (TypeVar u (QualName qs v) args)) = + case f v of + Just t' -> second (const mempty) $ substTyVars f t' + Nothing -> t +substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt +substTyVars f (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars f) fs +substTyVars f (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars f) cs +substTyVars f (Scalar (Arrow u pname d t1 (RetType ext t2))) = + Scalar $ + Arrow u pname d (substTyVars f t1) $ + RetType ext $ + substTyVars f t2 `setUniqueness` uniqueness t2 +substTyVars f (Array u shape elemt) = + arrayOfWithAliases u shape $ substTyVars f $ Scalar elemt + -- Note [AnySize] -- -- Consider a program: From 837f3cd57121906ddfaf428b7c66fe57c92ff153 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 17:02:16 +0100 Subject: [PATCH 148/296] Add checkSizeExp to Terms2. --- src/Language/Futhark/TypeChecker/Terms.hs | 5 ++--- src/Language/Futhark/TypeChecker/Terms2.hs | 24 ++++++++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 68547cb024..e953ae3c00 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1063,16 +1063,14 @@ checkOneExp e = do -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp checkSizeExp e = do - (maybe_tysubsts, e') <- Terms2.checkSingleExp e + (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' - let t = typeOf e'' when (hasBinding e'') $ typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ "Size expression with binding is forbidden." - unify (mkUsage e'' "Size expression") t (Scalar (Prim (Signed Int64))) normTypeFully e'' -- Verify that all sum type constructors and empty array literals have @@ -1657,6 +1655,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + debugTraceM $ unlines [unlines $ map show $ M.toList tysubsts, prettyString body'] let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index f65bbb1811..b56f534184 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -38,6 +38,7 @@ module Language.Futhark.TypeChecker.Terms2 ( checkValDef, checkSingleExp, + checkSizeExp, Solution, ) where @@ -442,8 +443,8 @@ patLitMkType (PatLitFloat _) loc = patLitMkType (PatLitPrim v) _ = pure $ Scalar $ Prim $ primValueType v -checkSizeExp :: ExpBase NoInfo VName -> TermM Exp -checkSizeExp e = do +checkSizeExp' :: ExpBase NoInfo VName -> TermM Exp +checkSizeExp' e = do e' <- checkExp e ctEq (expType e') (Scalar (Prim (Signed Int64))) pure e' @@ -496,7 +497,7 @@ checkPat' (RecordPat fs loc) NoneInferred = <$> traverse (`checkPat'` NoneInferred) (M.fromList fs) <*> pure loc checkPat' (PatAscription p t loc) maybe_outer_t = do - (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp t + (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp' t -- Uniqueness kung fu to make the Monoid(mempty) instance give what -- we expect. We should perhaps stop being so implicit. @@ -775,7 +776,7 @@ checkRetDecl :: TermM (Maybe (TypeExp Exp VName)) checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do - (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te ctEq (expType body) (toType st) pure $ Just te' @@ -1114,12 +1115,12 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do -- checkExp (Ascript e te loc) = do e' <- checkExp e - (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te ctEq (expType e') (toType st) pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e - (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp te + (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te ctEq (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc @@ -1225,3 +1226,14 @@ checkSingleExp e = runTermM $ do solution <- bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars pure (solution, e') + +-- | Type-check a single size expression in isolation. This expression may +-- turn out to be polymorphic, in which case it is unified with i64. +checkSizeExp :: ExpBase NoInfo VName -> TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) +checkSizeExp e = runTermM $ do + e' <- checkSizeExp' e + cts <- gets termConstraints + tyvars <- gets termTyVars + solution <- + bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars + pure (solution, e') From 12d21a47e5e6ad63b6784611df89acd6f360f6c1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 29 Feb 2024 22:47:09 +0100 Subject: [PATCH 149/296] Fixes to updating of types. --- src/Language/Futhark/TypeChecker/Terms.hs | 1 - .../Futhark/TypeChecker/Terms/Monad.hs | 1 + src/Language/Futhark/TypeChecker/Terms/Pat.hs | 30 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index e953ae3c00..5eae32ac6b 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1655,7 +1655,6 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do - debugTraceM $ unlines [unlines $ map show $ M.toList tysubsts, prettyString body'] let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained (tparams', params'', retdecl'', RetType dims rettype', body'') <- checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 3e23ceec2e..2153eb2751 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -25,6 +25,7 @@ module Language.Futhark.TypeChecker.Terms.Monad constrain, newArrayType, allDimsFreshInType, + instTyVars, replaceTyVars, updateTypes, Names, diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 17aef0e1a2..4485c20f4c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -15,6 +15,7 @@ import Data.List (find, isPrefixOf, sort) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Futhark.Util import Futhark.Util.Pretty hiding (group, space) import Language.Futhark import Language.Futhark.TypeChecker.Monad hiding (BoundV) @@ -104,9 +105,9 @@ bindingIdent ident = binding [ident] checkPat' :: [(SizeBinder VName, QualName VName)] -> - Pat (TypeBase Size u) -> + Pat ParamType -> Inferred ParamType -> - TermTypeM (Pat (TypeBase Size u)) + TermTypeM (Pat ParamType) checkPat' sizes (PatParens p loc) t = PatParens <$> checkPat' sizes p t <*> pure loc checkPat' sizes (PatAttr attr p loc) t = @@ -115,16 +116,14 @@ checkPat' _ (Id name (Info t) loc) NoneInferred = do t' <- replaceTyVars loc t pure $ Id name (Info t') loc checkPat' _ (Id name (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc t1 - unify (mkUsage loc "id") (toStruct t) (toStruct t2) - pure $ Id name (Info t) loc + t' <- instTyVars loc [] (first (const ()) t1) t2 + pure $ Id name (Info t') loc checkPat' _ (Wildcard (Info t) loc) NoneInferred = do t' <- replaceTyVars loc t pure $ Wildcard (Info t') loc checkPat' _ (Wildcard (Info t1) loc) (Ascribed t2) = do - t <- replaceTyVars loc t1 - unify (mkUsage loc "wildcard") (toStruct t) (toStruct t2) - pure $ Wildcard (Info t) loc + t' <- instTyVars loc [] (first (const ()) t1) t2 + pure $ Wildcard (Info t') loc checkPat' sizes p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, length ts == length ps = @@ -185,9 +184,9 @@ checkPat' sizes (PatConstr n info ps loc) _ = do checkPat :: [(SizeBinder VName, QualName VName)] -> - Pat (TypeBase Size u) -> + Pat ParamType -> Inferred StructType -> - (Pat (TypeBase Size u) -> TermTypeM a) -> + (Pat ParamType -> TermTypeM a) -> TermTypeM a checkPat sizes p t m = do p' <- @@ -210,14 +209,15 @@ bindingPat :: [SizeBinder VName] -> Pat (TypeBase Size u) -> StructType -> - (Pat (TypeBase Size u) -> TermTypeM a) -> + (Pat ParamType -> TermTypeM a) -> TermTypeM a bindingPat sizes p t m = do substs <- mapM mkSizeSubst sizes - checkPat substs p (Ascribed t) $ \p' -> binding (patIdents (fmap toStruct p')) $ - case filter ((`S.notMember` fvVars (freeInPat p')) . sizeName) sizes of - [] -> m p' - size : _ -> unusedSize size + checkPat substs (fmap (toParam Observe) p) (Ascribed t) $ \p' -> + binding (patIdents (fmap toStruct p')) $ + case filter ((`S.notMember` fvVars (freeInPat p')) . sizeName) sizes of + [] -> m p' + size : _ -> unusedSize size where mkSizeSubst v = do v' <- newID $ baseName $ sizeName v From c6aaf9ca507368589e5ff904668e06483c8e5e83 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 15:36:00 +0100 Subject: [PATCH 150/296] Fix unification of abstract types. --- src/Language/Futhark/TypeChecker/Constraints.hs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 10206b0bb8..fc4f1fb5f8 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -166,6 +166,14 @@ linkTyVar v t = do unify :: Type -> Type -> Maybe [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Just [] +unify + (Scalar (TypeVar _ (QualName _ v1) targs1)) + (Scalar (TypeVar _ (QualName _ v2) targs2)) + | v1 == v2 = + Just $ mapMaybe f $ zip targs1 targs2 + where + f (TypeArgType t1, TypeArgType t2) = Just (t1, t2) + f _ = Nothing unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = Just [(t1a, t2a), (t1r', t2r')] where @@ -223,13 +231,9 @@ solveCt ct = (Scalar (TypeVar _ (QualName [] v1) []), t2') | Just lvl <- flexible v1 -> subTyVar v1 lvl t2' - | otherwise -> - bad (t1', Scalar (TypeVar _ (QualName [] v2) [])) | Just lvl <- flexible v2 -> subTyVar v2 lvl t1' - | otherwise -> - bad (t1', t2') -> case unify t1' t2' of Nothing -> bad Just eqs -> mapM_ solveCt' eqs From b6d2e1b383b0cf4cafc23db411430c64e2eb10fe Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 15:44:29 +0100 Subject: [PATCH 151/296] Also do the AUTOMAP on size expressions. --- src/Language/Futhark/TypeChecker/Terms2.hs | 83 ++++++++++++---------- 1 file changed, 47 insertions(+), 36 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b56f534184..cc1f611129 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1174,47 +1174,48 @@ checkValDef :: [Exp] ) checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do - bindParams tparams params $ \params' -> do - body' <- checkExp body + (params', body', retdecl') <- + bindParams tparams params $ \params' -> do + body' <- checkExp body + retdecl' <- checkRetDecl body' retdecl + pure (params', body', retdecl') - retdecl' <- checkRetDecl body' retdecl + cts <- gets termConstraints - cts <- gets termConstraints + tyvars <- gets termTyVars - tyvars <- gets termTyVars + debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + debugTraceM $ + unlines + [ "## cts:", + unlines $ map prettyString cts, + "## body:", + prettyString body', + "## tyvars:", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + ] + + (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' + solutions <- + forM cts_tyvars' $ + bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + + forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> debugTraceM $ unlines - [ "## cts:", - unlines $ map prettyString cts, - "## body:", - prettyString body', - "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars':", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] - (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' - - solutions <- - forM cts_tyvars' $ - bitraverse pure (traverse (doDefaults mempty)) . uncurry solve - - forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> - debugTraceM $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars':", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - - pure (solutions, params', retdecl', bodys') + pure (solutions, params', retdecl', bodys') checkSingleExp :: ExpBase NoInfo VName -> @@ -1229,11 +1230,21 @@ checkSingleExp e = runTermM $ do -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. -checkSizeExp :: ExpBase NoInfo VName -> TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) +checkSizeExp :: + ExpBase NoInfo VName -> + TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints tyvars <- gets termTyVars - solution <- - bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars - pure (solution, e') + + (cts_tyvars', es') <- unzip <$> rankAnalysis (srclocOf e) cts tyvars e' + + solutions <- + forM cts_tyvars' $ + bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + + case (solutions, es') of + ([solution], [e'']) -> + pure (solution, e'') + _ -> pure (Left "Ambiguous size expression", e') From 45a5c448578c5a1f03935de811ae400d89b202e8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 16:45:21 +0100 Subject: [PATCH 152/296] Fix instantiation of parametric abstract types. --- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 2153eb2751..086580745d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -411,6 +411,14 @@ instTyVars loc names orig_t1 orig_t2 = do (Array _ (Shape (d : ds2)) t2) = arrayOfWithAliases u (Shape [d]) <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) + f + (Scalar (TypeVar u v1 targs1)) + (Scalar (TypeVar _ _ targs2)) = + Scalar . TypeVar u v1 <$> zipWithM g targs1 targs2 + where + g (TypeArgType t1) (TypeArgType t2) = + TypeArgType <$> f t1 t2 + g _ targ = pure targ f t1 t2 = do let mkNew = fst <$> lift (allDimsFreshInType (mkUsage loc "instantiation") Nonrigid "dv" t1) @@ -538,11 +546,6 @@ lookupVar loc qn@(QualName qs name) inst_t = do replaceTyVars loc inst_t Just OverloadedF {} -> replaceTyVars loc inst_t - where - instOverloaded argtype pts rt = - ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, - maybe (toStruct argtype) (Scalar . Prim) rt - ) onFailure :: Checking -> TermTypeM a -> TermTypeM a onFailure c = local $ \env -> env {termChecking = Just c} From a3a50945e72882c5de16bc2ed7374f0496f05844 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 1 Mar 2024 17:06:56 +0100 Subject: [PATCH 153/296] Must match. --- src/Language/Futhark/TypeChecker/Terms/Monad.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 086580745d..018a3e920a 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -60,6 +60,7 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified +import Futhark.Util import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals @@ -413,8 +414,9 @@ instTyVars loc names orig_t1 orig_t2 = do <$> f (arrayOf (Shape ds1) (Scalar t1)) (arrayOf (Shape ds2) (Scalar t2)) f (Scalar (TypeVar u v1 targs1)) - (Scalar (TypeVar _ _ targs2)) = - Scalar . TypeVar u v1 <$> zipWithM g targs1 targs2 + (Scalar (TypeVar _ _ targs2)) + | length targs1 == length targs2 = + Scalar . TypeVar u v1 <$> zipWithM g targs1 targs2 where g (TypeArgType t1) (TypeArgType t2) = TypeArgType <$> f t1 t2 @@ -461,7 +463,6 @@ instTypeScheme qn loc tparams scheme_t inferred = do let tp_names = map typeParamName $ filter isTypeParam tparams t' <- instTyVars loc tp_names inferred $ applySubst (`lookup` substs) scheme_t - pure (names, t') lookupQualNameEnv :: QualName VName -> TermTypeM TermScope From 989c75778003feb0a2509d4827d5a3425fae6ff7 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 1 Mar 2024 17:49:27 -0800 Subject: [PATCH 154/296] Fix `mri-q.fut`. --- tests/automap/mri-q.fut | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index eaed14333a..8fe26aded6 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -23,18 +23,17 @@ def main_am [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) (phiR: [numK]f32) (phiI: [numK]f32) - : ([numK]f32, [numX][numK]f32) = - let (phiMag : [numK]f32) = phiR * phiR + phiI * phiI - let (expArgs : [numX][numK]f32) = map3 (\(x_e : f32) (y_e : f32) (z_e : f32) -> + : ([numX]f32, [numX]f32) = + let phiMag = phiR * phiR + phiI * phiI + let expArgs = map3 (\x_e y_e z_e -> 2.0*f32.pi*(kx*x_e + ky*y_e + kz*z_e)) x y z - in (phiMag, expArgs) - --let (qr : [numX]f32) = f32.sum (f32.cos expArgs * phiMag) -- [numx]f32 - --let (qi : [numX]f32) = f32.sum (f32.sin expArgs * phiMag) -- let (qi_10408: artificial₁₁₄_10524 ~ [M113_10523]f32) - --in (qr, qi) + let qr = f32.sum (f32.cos expArgs * phiMag) + let qi = f32.sum (f32.sin expArgs * phiMag) + in (qr, qi) ---entry main [numK][numX] --- (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) --- (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) --- (phiR: [numK]f32) (phiI: [numK]f32) = --- main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI +entry main [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) : bool = + main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI From 101285a0a4d907225786a55378e94fe4756bb1c2 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 1 Mar 2024 17:38:11 -0800 Subject: [PATCH 155/296] `debugTraceM` now takes a level. --- src/Futhark/Util.hs | 8 ++++---- src/Language/Futhark/Pretty.hs | 4 ++-- src/Language/Futhark/TypeChecker.hs | 2 +- src/Language/Futhark/TypeChecker/Rank.hs | 6 +++--- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 10 +++++----- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index 6d41b1c7f2..6a97f25e0e 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -510,8 +510,8 @@ topologicalSort dep nodes = mapM_ sorting $ mapMaybe (depends_of node) nodes_idx modify $ bimap (node :) (IM.insert i False) --- | 'traceM', but only if @FUTHARK_COMPILER_DEBUGGING@ is set to at least 1. -debugTraceM :: (Monad m) => String -> m () -debugTraceM - | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 = traceM +-- | 'traceM', but only if @FUTHARK_COMPILER_DEBUGGING@ is set to to the appropriate level. +debugTraceM :: (Monad m) => Int -> String -> m () +debugTraceM level + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" level = traceM | otherwise = const $ pure () diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 3318fe11e4..8e0b2619d9 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -237,7 +237,7 @@ prettyAppExp :: (Eq vn, IsName vn, Annot f) => Int -> AppExpBase f vn -> Doc a prettyAppExp p (BinOp (bop, _) _ (x, xi) (y, yi) _) = case (unAnnot xi, unAnnot yi) of (Just (_, xam), Just (_, yam)) - | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 3 -> -- fix parens $ align $ prettyBinOp p bop x y "Δ" <+> pretty xam "Δ" <+> pretty yam _ -> prettyBinOp p bop x y @@ -322,7 +322,7 @@ prettyAppExp p (Apply f args _) = prettyArg (i, e) = case unAnnot i of Just (_, am) - | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 -> + | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 3 -> parens (prettyExp 10 e <+> "Δ" <+> pretty am) _ -> prettyExp 10 e diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 65fd30f220..70824ad148 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -714,7 +714,7 @@ checkValBind vb = do let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - debugTraceM $ unlines ["# Inferred:", prettyString vb'] + debugTraceM 3 $ unlines ["# Inferred:", prettyString vb'] pure ( mempty diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 186f8d0fb3..9672e9aa60 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -195,7 +195,7 @@ enumerateRankSols prog = solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] solveRankILP loc prog = do - debugTraceM $ + debugTraceM 3 $ unlines [ "## solveRankILP", prettyString prog @@ -203,9 +203,9 @@ solveRankILP loc prog = do case enumerateRankSols prog of [] -> typeError loc mempty "Rank ILP cannot be solved." rs -> do - debugTraceM "## rank maps" + debugTraceM 3 "## rank maps" forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> - debugTraceM $ + debugTraceM 3 $ unlines $ "\n## rank map " <> prettyString i : map prettyString (M.toList r) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 5eae32ac6b..ec30e59689 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -963,7 +963,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d then splitArrayAt (autoFrameRank am) <$> normTypeFully tp1 else splitArrayAt (autoFrameRank am) <$> normTypeFully argtype - debugTraceM $ + debugTraceM 3 $ unlines [ "## checkApply", "## fn", diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cc1f611129..1e8dbcefbb 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -640,8 +640,8 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do rhs = arrayOf (toShape (SVar m)) a ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs - debugTraceM $ - unlines + debugTraceM 3 $ + unlines $ [ "## checkApplyOne", "## fname", prettyString fname, @@ -1184,9 +1184,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do tyvars <- gets termTyVars - debugTraceM $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" + debugTraceM 3 $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" - debugTraceM $ + debugTraceM 3 $ unlines [ "## cts:", unlines $ map prettyString cts, @@ -1203,7 +1203,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bitraverse pure (traverse (doDefaults mempty)) . uncurry solve forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> - debugTraceM $ + debugTraceM 3 $ unlines [ "## constraints:", unlines $ map prettyString cts', From 178a7a09e30b262be1b8263ca9295cd9af007d10 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Fri, 1 Mar 2024 10:57:46 -0800 Subject: [PATCH 156/296] Expand AUTOMAP annotations in normalization. --- src/Futhark/Internalise/Exps.hs | 13 + src/Futhark/Internalise/FullNormalise.hs | 344 ++++++++++++++++++++++- 2 files changed, 356 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index b5684552e4..11b261e425 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1624,6 +1624,19 @@ isIntrinsicFunction qname args loc = do fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing + handleSOACs (lam : args) "map" = Just $ \desc -> do + let internaliseVName x = do + es <- map (BasicOp . SubExp) <$> internaliseExp "arg" x + concat <$> mapM (letValExp "arg") es + args' <- concat <$> mapM internaliseVName args + param_ts <- mapM (fmap (I.stripArray 1) . lookupType) args' + map_dim <- (head . I.shapeDims . I.arrayShape) <$> lookupType (head args') + lambda <- internaliseLambdaCoerce lam param_ts + letTupExp' + desc + $ Op + $ Screma map_dim args' + $ mapSOAC lambda handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index a3dca1f8bd..fe197a842c 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -23,11 +23,18 @@ module Futhark.Internalise.FullNormalise (transformProg) where import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor +import Data.Functor.Identity +import Data.List (zip4) import Data.List.NonEmpty qualified as NE import Data.Map qualified as M +import Data.Maybe import Data.Text qualified as T +import Debug.Trace import Futhark.MonadFreshNames +import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.Pretty +import Language.Futhark.Primitive (intValue) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types @@ -347,7 +354,7 @@ getOrdering final (AppExp (Match expr cs loc) resT) = do -- a complete separtion of states. transformBody :: (MonadFreshNames m) => Exp -> m Exp transformBody e = do - (e', pre_eval) <- runOrdering (getOrdering True e) + (e', pre_eval) <- runOrdering . getOrdering True =<< expandAMAnnotations e pure $ foldl f e' pre_eval where appRes = case e of @@ -366,3 +373,338 @@ transformValBind valbind = do transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind] transformProg = mapM transformValBind + +--- | Expansion of 'AutoMap'-annotated applications. +--- +--- Each application @f x@ has an annotation with @AutoMap R M F@ where +--- @R, M, F@ are the autorep, automap, and frame shapes, +--- respectively. +--- +--- The application @f x@ will have type @F t@ for some @t@, i.e. @(f +--- x) : F t@. The frame @F@ is a prefix of the type of @f x@; namely +--- it is the total accumulated shape that is due to implicit maps. +--- Another way of thinking about that is that @|F|@ is is the level +--- of the automap-nest that @f x@ is in. For example, if @|F| = 2@ +--- then we know that @f x@ implicitly stands for +--- +--- > map (\x' -> map (\x'' -> f x'') x') x +--- +--- For an application with a non-empty autorep annotation, the frame +--- tells about how many dimensions of the replicate can be eliminated. +--- For example, @[[1,2],[3,4]] + 5@ will yield the following annotations: +--- +--- > ([[1,2],[3,4]] +) -- AutoMap {R = mempty, M = [2][2], F = [2][2]} +--- > (([[1,2],[3,4]] +) 5) -- AutoMap {R = [2][2], M = mempty, F = [2][2]} +--- +--- All replicated arguments are pushed down the auto-map nest. Each +--- time a replicated argument is pushed down a level of an +--- automap-nest, one fewer replicates is needed (i.e., the outermost +--- dimension of @R@ can be dropped). Replicated arguments are pushed +--- down the nest until either 1) the bottom of the nest is encountered +--- or 2) no replicate dimensions remain. For example, in the second +--- application above @R@ = @F@, so we can push the replicated argument +--- down two levels. Since each level effectively removes a dimension +--- of the replicate, no replicates will be required: +--- +--- > map (\xs -> map (\x -> f x'' 5) xs) [[1,2],[3,4]] +--- +--- The number of replicates that are actually required is given by +--- max(|R| - |F|, 0). +--- +--- An expression's "true level" is the level at which that expression +--- will appear in the automap-nest. The bottom of a mapnest is level 0. +--- +--- * For annotations with @R = mempty@, the true level is @|F|@. +--- * For annotations with @M = mempty@, the true level is @|F| - |R|@. +--- +--- If @|R| > |F|@ then actual replicates (namely @|R| - |F|@ of them) +--- will be required at the bottom of the mapnest. +--- +--- Note that replicates can only appear at the bottom of a mapnest; any +--- expression of the form +--- +--- > map (\ls x' rs -> e) (replicate x) +--- +--- can always be written as +--- +--- > map (\ls rs -> e[x' -> x]) +--- +--- Let's look at another example. Consider (with exact sizes omitted for brevity) +--- +--- > f : a -> a -> a -> []a -> [][][]a -> a +--- > xss : [][]a +--- > ys : []a +--- > zsss : [][][]a +--- > w : a +--- > vss : [][]a +--- +--- and the application +--- +--- > f xss ys zsss w vss +--- +--- which will have the following annotations +--- +--- > (f xss) -- AutoMap {R = mempty, M = [][], F = [][]} (1) +--- > ((f xss) ys) -- AutoMap {R = [], M = mempty, F = [][]} (2) +--- > (((f xss) ys) zsss) -- AutoMap {R = mempty, M = [], F = [][][]} (3) +--- > ((((f xss) ys) zsss) w) -- AutoMap {R = [][][][], M = mempty, F = [][][]} (4) +--- > (((((f xss) ys) zsss) w) vss) -- AutoMap {R = [], M = mempty, F = [][][]} (5) +--- +--- This will yield the following mapnest. +--- +--- > map (\zss -> +--- > map (\xs zs vs -> +--- > map (\x y z v -> f x y z (replicate w) v) xs ys zs v) xss zss vss) zsss +--- +--- Let's see how we'd construct this mapnest from the annotations. We construct +--- the nest bottom-up. We have: +--- +--- Application | True level +--- --------------------------- +--- (1) | |[][]| = 2 +--- (2) | |[][]| - |[]| = 1 +--- (3) | |[][][]| = 3 +--- (4) | |[][][]| - |[][][][]| = -1 +--- (5) | |[][][]| - |[]| = 2 +--- +--- We start at level 0. +--- * Any argument with a negative true level of @-n@ will be replicated @n@ times; +--- the exact shapes can be found by removing the @F@ postfix from @R@, +--- i.e. @R = shapes_to_rep_by <> F@. +--- * Any argument with a 0 true level will be included. +--- * For any argument @arg@ with a positive true level, we construct a new parameter +--- whose type is @arg@ with the leading @n@ dimensions (where @n@ is the true level) +--- removed. +--- +--- Following the rules above, @w@ will be replicated once. For the remaining arguments, +--- we create new parameters @x : a, y : a, z : a , v : a@. Hence, level 0 becomes +--- +--- > f x y z (replicate w) v +--- +--- At level l > 0: +--- * There are no replicates. +--- * Any argument with l true level will be included verbatim. +--- * Any argument with true level > l will have a new parameter constructed for it, +--- whose type has the leading @n - l@ dimensions (where @n@ is the true level) removed. +--- * We surround the previous level with a map that binds that levels' new parameters +--- and is passed the current levels' arguments. +--- +--- Following the above recipe for level 1, we create parameters +--- @xs : []a, zs : []a, vs :[]a@ and obtain +--- +--- > map (\x y z v -> f x y z (replicate w) v) xs ys zs vs +--- +--- This process continues until the level is greater than the maximum +--- true level of any application, at which we terminate. + +-- | Expands 'AutoMap' annotations into explicit @map@s and @replicates@. +expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp +expandAMAnnotations e = do + case e of + (AppExp (Apply f args loc) (Info res)) -> do + let ((exts, ams), arg_es) = first unzip $ unzip $ map (first unInfo) $ NE.toList args + f' <- expandAMAnnotations f + arg_es' <- mapM expandAMAnnotations arg_es + let diets = funDiets $ typeOf f + withMapNest loc (zip4 exts ams arg_es diets) $ \args' -> do + inner_f <- setNewType f' $ innerFType (typeOf f') ams + let (_, ret) = unfoldFunType $ typeOf inner_f + + -- when (any (/= mempty) ams) $ + -- traceM $ + -- unlines $ + -- [ "##f'", + -- prettyString $ typeOf f', + -- "##inner_f", + -- prettyString $ typeOf inner_f, + -- "##e", + -- prettyString e, + -- "##ams", + -- show ams + -- ] + pure $ + mkApply inner_f (zip3 exts (repeat mempty) args') $ + res {appResType = snd $ unfoldFunType $ typeOf inner_f} + (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do + x' <- expandAMAnnotations x + y' <- expandAMAnnotations y + withMapNest loc [(xext, xam, x', Observe), (yext, yam, y', Observe)] $ \[x'', y''] -> + pure $ + AppExp + ( BinOp + op + (Info t) + (x'', Info (xext, mempty)) + (y'', Info (yext, mempty)) + loc + ) + (Info res {appResType = stripArray (shapeRank $ autoFrame yam) (appResType res)}) + _ -> astMap identityMapper {mapOnExp = expandAMAnnotations} e + where + setNewType e t = astMap identityMapper {mapOnStructType = const $ pure t} e + + funDiets :: TypeBase dim as -> [Diet] + funDiets (Scalar (Arrow _ _ d _ (RetType _ t2))) = d : funDiets t2 + funDiets _ = [] + + dropDims :: Int -> TypeBase dim as -> TypeBase dim as + dropDims n (Scalar (Arrow u p diet t1 (RetType ds t2))) = + Scalar (Arrow u p diet (stripArray n t1) (RetType ds (dropDims n t2))) + dropDims n t = stripArray n t + + innerFType :: TypeBase dim as -> [AutoMap] -> TypeBase dim as + innerFType (Scalar (Arrow u p diet t1 (RetType ds t2))) ams = + Scalar $ Arrow u p diet t1 $ RetType ds $ innerFType' t2 ams + where + innerFType' t [] = t + innerFType' (Scalar (Arrow u p diet t1 (RetType ds t2))) (am : ams) = + Scalar $ Arrow u p diet (dropDims (shapeRank (autoMap am)) t1) $ RetType ds $ innerFType' t2 ams + innerFType' t [am] = dropDims (shapeRank (autoMap am)) t + innerFType' _ _ = error "" + innerFType _ _ = error "" + +type Level = Int + +data AutoMapArg = AutoMapArg + { amArg :: Exp + } + deriving (Show) + +data AutoMapParam = AutoMapParam + { amParam :: Pat ParamType, + amMapDim :: Size, + amDiet :: Diet + } + deriving (Show) + +-- | Builds a map-nest based on the 'AutoMap' annotations. +withMapNest :: + forall m. + (MonadFreshNames m) => + SrcLoc -> + [(Maybe VName, AutoMap, Exp, Diet)] -> + ([Exp] -> m Exp) -> + m Exp +withMapNest loc args f = do + (param_map, arg_map) <- + bimap combineMaps combineMaps . unzip <$> mapM buildArgMap args + buildMapNest param_map arg_map $ maximum $ M.keys arg_map + where + combineMaps :: (Ord k) => [M.Map k v] -> M.Map k [v] + combineMaps = M.unionsWith (<>) . (fmap . fmap) pure + + buildMapNest :: + M.Map Level [AutoMapParam] -> + M.Map Level [AutoMapArg] -> + Level -> + m Exp + buildMapNest _ arg_map 0 = + f $ map amArg $ arg_map M.! 0 + buildMapNest param_map arg_map l = + case map amMapDim $ param_map M.! l of + [] -> error "Malformed param map." + (map_dim : _) -> do + let params = map (\p -> (amDiet p, amParam p)) $ param_map M.! l + args = map amArg $ arg_map M.! l + body <- buildMapNest param_map arg_map (l - 1) + pure $ + mkMap map_dim params body args $ + RetType [] $ + arrayOfWithAliases Unique (Shape [map_dim]) (typeOf body) + + buildArgMap :: + (Maybe VName, AutoMap, Exp, Diet) -> + m (M.Map Level AutoMapParam, M.Map Level AutoMapArg) + buildArgMap (ext, am, arg, diet) = + foldM (mkArgsAndParams arg) mempty $ reverse [0 .. trueLevel am] + where + mkArgsAndParams arg (p_map, a_map) l + | l == 0 = do + let arg' = maybe arg (paramToExp . amParam) (p_map M.!? 1) + rarg <- mkReplicateShape (autoRep am `shapePrefix` autoFrame am) arg' + pure (p_map, M.insert 0 (AutoMapArg rarg) a_map) + | l == trueLevel am = do + p <- mkAMParam (typeOf arg) l + let d = outerDim am l + pure + ( M.insert l (AutoMapParam p d diet) p_map, + M.insert l (AutoMapArg arg) a_map + ) + | l < trueLevel am && l > 0 = do + p <- mkAMParam (typeOf arg) l + let d = outerDim am l + let arg' = + paramToExp $ + amParam $ + p_map M.! (l + 1) + pure + ( M.insert l (AutoMapParam p d diet) p_map, + M.insert l (AutoMapArg arg') a_map + ) + | otherwise = error "Impossible." + + mkAMParam t level = + mkParam ("p_" <> show level) $ argType (level - 1) am t + + trueLevel :: AutoMap -> Int + trueLevel am + | autoMap am == mempty = + max 0 $ shapeRank (autoFrame am) - shapeRank (autoRep am) + | otherwise = + shapeRank $ autoFrame am + + outerDim :: AutoMap -> Int -> Size + outerDim am level = + (!! (trueLevel am - level)) $ shapeDims $ autoFrame am + + argType level am = stripArray (trueLevel am - level) + +mkParam :: (MonadFreshNames m) => String -> TypeBase Size u -> m (Pat ParamType) +mkParam desc t = do + x <- newVName desc + pure $ Id x (Info $ toParam Observe t) mempty + +mkReplicateShape :: (MonadFreshNames m) => Shape Size -> Exp -> m Exp +mkReplicateShape s e = foldM (flip mkReplicate) e s + +mkReplicate :: (MonadFreshNames m) => Exp -> Exp -> m Exp +mkReplicate dim e = do + x <- mkParam "x" (Scalar $ Prim $ Unsigned Int64) + pure $ + mkMap dim [(Observe, x)] e [xs] $ + RetType mempty (arrayOfWithAliases Unique (Shape [dim]) (typeOf e)) + where + xs = + AppExp + ( Range + (Literal (UnsignedValue $ intValue Int64 0) mempty) + Nothing + (UpToExclusive dim) + mempty + ) + ( Info $ AppRes (arrayOf (Shape [dim]) (Scalar $ Prim $ Unsigned Int64)) [] + ) + +mkMap :: Exp -> [(Diet, Pat ParamType)] -> Exp -> [Exp] -> ResRetType -> Exp +mkMap dim params body arrs rettype = + mkApply mapN args (AppRes (toStruct $ retType rettype) []) + where + args = map (Nothing,mempty,) $ lambda : arrs + mapt = foldFunType (zipWith toParam (Observe : map fst params) (typeOf lambda : map typeOf arrs)) rettype + mapN = Var (QualName [] $ VName "map" 0) (Info mapt) mempty + lambda = + Lambda + (map snd params) + body + Nothing + ( Info $ + RetType + (retDims rettype) + (typeOf body `setUniqueness` uniqueness (retType rettype)) + ) + mempty + +paramToExp :: Pat ParamType -> Exp +paramToExp (Id vn (Info t) loc) = + Var (QualName [] vn) (Info $ toStruct t) loc +paramToExp p = error $ prettyString p From e971922f16054fe4ac13b7c8b9e67475326d5798 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 3 Mar 2024 14:13:17 -0800 Subject: [PATCH 157/296] Apostrophes are important, man. --- src/Futhark/Internalise/FullNormalise.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index fe197a842c..c615e49c3d 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -506,7 +506,7 @@ expandAMAnnotations e = do f' <- expandAMAnnotations f arg_es' <- mapM expandAMAnnotations arg_es let diets = funDiets $ typeOf f - withMapNest loc (zip4 exts ams arg_es diets) $ \args' -> do + withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do inner_f <- setNewType f' $ innerFType (typeOf f') ams let (_, ret) = unfoldFunType $ typeOf inner_f From d62614f44df55571c63a988b9a7175f73f83df03 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Sun, 3 Mar 2024 18:00:58 -0800 Subject: [PATCH 158/296] Fixes/clean-up. --- src/Futhark/Internalise/Exps.hs | 24 +++++++++++++++++------- src/Futhark/Internalise/FullNormalise.hs | 15 +-------------- tests/automap/mri-q.fut | 6 ++++-- tests/automap/optionpricing.fut | 12 ++++++------ tests/automap/sgemm.fut | 2 +- 5 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 11b261e425..c5fe0a8fcd 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1625,17 +1625,27 @@ isIntrinsicFunction qname args loc = do handleOps _ _ = Nothing handleSOACs (lam : args) "map" = Just $ \desc -> do - let internaliseVName x = do - es <- map (BasicOp . SubExp) <$> internaliseExp "arg" x - concat <$> mapM (letValExp "arg") es - args' <- concat <$> mapM internaliseVName args - param_ts <- mapM (fmap (I.stripArray 1) . lookupType) args' - map_dim <- (head . I.shapeDims . I.arrayShape) <$> lookupType (head args') + arg_ses <- concat <$> mapM (internaliseExp "arg") args + arg_ts <- mapM subExpType arg_ses + let param_ts = map rowType arg_ts + map_dim = head $ I.shapeDims $ I.arrayShape $ head arg_ts + + arg_ses' <- + zipWithM + ( \p a -> + ensureShape "" mempty (arrayOfRow p map_dim) "" a + ) + param_ts + arg_ses + + args_v'' <- mapM (letExp "" . BasicOp . SubExp) arg_ses' + lambda <- internaliseLambdaCoerce lam param_ts + letTupExp' desc $ Op - $ Screma map_dim args' + $ Screma map_dim args_v'' $ mapSOAC lambda handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index c615e49c3d..52dc0fb0f5 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -509,19 +509,6 @@ expandAMAnnotations e = do withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do inner_f <- setNewType f' $ innerFType (typeOf f') ams let (_, ret) = unfoldFunType $ typeOf inner_f - - -- when (any (/= mempty) ams) $ - -- traceM $ - -- unlines $ - -- [ "##f'", - -- prettyString $ typeOf f', - -- "##inner_f", - -- prettyString $ typeOf inner_f, - -- "##e", - -- prettyString e, - -- "##ams", - -- show ams - -- ] pure $ mkApply inner_f (zip3 exts (repeat mempty) args') $ res {appResType = snd $ unfoldFunType $ typeOf inner_f} @@ -610,7 +597,7 @@ withMapNest loc args f = do pure $ mkMap map_dim params body args $ RetType [] $ - arrayOfWithAliases Unique (Shape [map_dim]) (typeOf body) + arrayOfWithAliases Nonunique (Shape [map_dim]) (typeOf body) buildArgMap :: (Maybe VName, AutoMap, Exp, Diet) -> diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index 8fe26aded6..3a4648c7b9 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -35,5 +35,7 @@ def main_am [numK][numX] entry main [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) - (phiR: [numK]f32) (phiI: [numK]f32) : bool = - main_orig kx ky kz x y z phiR phiI == main_am kx ky kz x y z phiR phiI + (phiR: [numK]f32) (phiI: [numK]f32) = + let (qr, qi) = main_orig kx ky kz x y z phiR phiI + let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI + in and (map2 (==) qr qr_am && qi == qi_am) diff --git a/tests/automap/optionpricing.fut b/tests/automap/optionpricing.fut index c58bc39a0a..c4c916521f 100644 --- a/tests/automap/optionpricing.fut +++ b/tests/automap/optionpricing.fut @@ -68,11 +68,11 @@ def sobolRecI_am [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i def sobolReci2_am [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= prev ^ recM(sob_dirs, i) -entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): []bool = - sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n +entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): bool = + and (sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n) -entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): []bool = - sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x) +entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): bool = + and (sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x)) -entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): []bool = - sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i) +entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): bool = + and (sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i)) diff --git a/tests/automap/sgemm.fut b/tests/automap/sgemm.fut index 56dc08eb7e..a31ce0188e 100644 --- a/tests/automap/sgemm.fut +++ b/tests/automap/sgemm.fut @@ -29,4 +29,4 @@ def main_am [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) entry main [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) (alpha: f32) (beta: f32) = - main_orig ass bss css alpha beta == main_am ass bss css alpha beta + and (and (main_orig ass bss css alpha beta == main_am ass bss css alpha beta)) From 51c5393d515c8a3ec8f4ab25cfc3ecbd030e471a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 4 Mar 2024 10:47:33 +0100 Subject: [PATCH 159/296] Tuples must have more than one field. --- src/Language/Futhark/Tuple.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Tuple.hs b/src/Language/Futhark/Tuple.hs index a410ae0a5a..63cf7c1188 100644 --- a/src/Language/Futhark/Tuple.hs +++ b/src/Language/Futhark/Tuple.hs @@ -17,7 +17,8 @@ import Language.Futhark.Core (Name, nameFromString, nameToText) areTupleFields :: M.Map Name a -> Maybe [a] areTupleFields fs = let fs' = sortFields fs - in if and $ zipWith (==) (map fst fs') tupleFieldNames + in if length fs' > 1 + && and (zipWith (==) (map fst fs') tupleFieldNames) then Just $ map snd fs' else Nothing From 8260a4da843a25570a9fd463ff2cf49951fe953d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 4 Mar 2024 13:44:37 +0100 Subject: [PATCH 160/296] Start handling overloaded tyvars. --- .../Futhark/TypeChecker/Constraints.hs | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index fc4f1fb5f8..ebac367dfb 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -108,6 +108,16 @@ substTyVar m v = Just (TyVarUnsol {}) -> Nothing Nothing -> Nothing +lookupTyVar :: TyVar -> SolveM (Maybe Type) +lookupTyVar orig = do + tyvars <- gets solverTyVars + let f v = case M.lookup v tyvars of + Nothing -> error $ "Unknown tyvar: " <> prettyNameString v + Just (TyVarSol _ t) -> pure $ Just t + Just (TyVarLink v') -> f v' + Just (TyVarUnsol {}) -> pure Nothing + f orig + -- | A solution maps a type variable to its substitution. This -- substitution is complete, in the sense there are no right-hand -- sides that contain a type variable. @@ -238,11 +248,46 @@ solveCt ct = Nothing -> bad Just eqs -> mapM_ solveCt' eqs +solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () +solveTyVar (tv, (_, TyVarFree {})) = pure () +solveTyVar (tv, (_, TyVarPrim pts)) = do + t <- lookupTyVar tv + case t of + Nothing -> pure () + Just t' + | t' `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + throwError $ + "Type variable " + <> prettyNameText tv + <> " must be one of\n" + <> prettyText pts + <> "\nbut inferred to be\n" + <> prettyText t' +solveTyVar (tv, (_, TyVarRecord fs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Nothing -> pure () + Just (Scalar (Record fs2)) + | all (`M.member` fs2) (M.keys fs1) -> + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(k, (t1, t2)) -> + solveCt $ CtEq t1 t2 + Just tv_t' -> + throwError $ + "Type variable " + <> prettyNameText tv + <> " must be record with fields\n" + <> prettyText (Scalar (Record fs1)) + <> " but inferred to be\n" + <> prettyText tv_t' + solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) solve constraints tyvars = second solution . runExcept . flip execStateT (initialState tyvars) . runSolveM - $ mapM solveCt constraints + $ do + mapM_ solveCt constraints + mapM_ solveTyVar (M.toList tyvars) {-# NOINLINE solve #-} From 5bab301710d0c33293940ed9064f9600b19a14e3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 4 Mar 2024 13:45:21 +0100 Subject: [PATCH 161/296] This order is better. --- src/Language/Futhark/TypeChecker.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 70824ad148..a69722ee2b 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -708,14 +708,14 @@ checkValBind vb = do checkFunDef (fname, maybe_tdecl, tparams, params, body, loc) let entry' = Info (entryPoint params' maybe_tdecl' rettype) <$ entry + vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc + + debugTraceM 3 $ unlines ["# Inferred:", prettyString vb'] + case entry' of Just _ -> checkEntryPoint loc tparams' params' maybe_tdecl' rettype _ -> pure () - let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc - - debugTraceM 3 $ unlines ["# Inferred:", prettyString vb'] - pure ( mempty { envVtable = From 9e2c1ee9b77ab60538687e2261d6202eb58839f3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 5 Mar 2024 14:02:50 +0100 Subject: [PATCH 162/296] Refactor to return one list with everything. --- src/Language/Futhark/TypeChecker/Terms.hs | 13 ++++++----- src/Language/Futhark/TypeChecker/Terms2.hs | 26 ++++++++++------------ 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index ec30e59689..cdc25b99d6 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -22,7 +22,7 @@ import Data.Bifunctor import Data.Bitraversable import Data.Char (isAscii) import Data.Either -import Data.List (delete, find, genericLength, partition) +import Data.List (delete, find, genericLength, partition, unzip4) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe @@ -1635,12 +1635,13 @@ checkFunDef :: Exp ) checkFunDef (fname, retdecl, tparams, params, body, loc) = do - (maybe_tysubstss, params', retdecl', bodys') <- + solutions <- Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - case (maybe_tysubstss, bodys') of - ([], _) -> error "impossible" - ([maybe_tysubsts], [body']) -> doChecks (maybe_tysubsts, params', retdecl', body') - (substs, bodies') -> + case solutions of + [(maybe_tysubsts, params', retdecl', body')] -> + doChecks (maybe_tysubsts, params', retdecl', body') + ls -> do + let (_, _, _, bodies') = unzip4 ls typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 1e8dbcefbb..b6953d810f 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -641,7 +641,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctAM r m $ fmap toSComp (toShape m_var <> fframe) ctEq lhs rhs debugTraceM 3 $ - unlines $ + unlines [ "## checkApplyOne", "## fname", prettyString fname, @@ -1168,11 +1168,12 @@ checkValDef :: SrcLoc ) -> TypeM - ( [Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness))], - [Pat ParamType], - Maybe (TypeExp Exp VName), - [Exp] - ) + [ ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), + [Pat ParamType], + Maybe (TypeExp Exp VName), + Exp + ) + ] checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do (params', body', retdecl') <- bindParams tparams params $ \params' -> do @@ -1196,13 +1197,11 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - (cts_tyvars', bodys') <- unzip <$> rankAnalysis loc cts tyvars body' - - solutions <- - forM cts_tyvars' $ - bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + ranks <- rankAnalysis loc cts tyvars body' - forM_ (zip solutions cts_tyvars') $ \(solution, (cts', tyvars')) -> + forM ranks $ \((cts', tyvars'), body'') -> do + solution <- + bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' debugTraceM 3 $ unlines [ "## constraints:", @@ -1214,8 +1213,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do in either T.unpack (unlines . map p . M.toList . snd) solution, either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution ] - - pure (solutions, params', retdecl', bodys') + pure (solution, params', retdecl', body'') checkSingleExp :: ExpBase NoInfo VName -> From a6676efdc7f59f9e3a4848dbf84e6cfbfc360abc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 5 Mar 2024 14:48:18 +0100 Subject: [PATCH 163/296] Let-generalise in Terms2. --- src/Language/Futhark/TypeChecker/Terms.hs | 82 +++++++++++----------- src/Language/Futhark/TypeChecker/Terms2.hs | 76 ++++++++++++++------ 2 files changed, 96 insertions(+), 62 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index cdc25b99d6..22ecc9ebc0 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1048,7 +1048,7 @@ checkOneExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' (tparams, _, _) <- @@ -1066,7 +1066,7 @@ checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of Left err -> typeError e' mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' when (hasBinding e'') $ typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ @@ -1492,14 +1492,16 @@ closeOverTypes defname defloc tparams paramts ret substs = do _ -> Nothing pure - ( tparams ++ more_tparams, + ( tparams + ++ more_tparams, injectExt (nubOrd $ retext ++ mapMaybe mkExt (S.toList $ fvVars $ freeInType ret)) ret ) where -- Diet does not matter here. t = foldFunType (map (toParam Observe) paramts) $ RetType [] ret - to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs visible = typeVars t <> fvVars (freeInType t) + to_close_over = + M.filterWithKey (\k _ -> k `S.member` visible) substs (produced_sizes, param_sizes) = dimUses t @@ -1655,39 +1657,39 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = do doChecks (maybe_tysubsts, params', retdecl', body') = case maybe_tysubsts of Left err -> typeError loc mempty $ pretty err - Right (unconstrained, tysubsts) -> runTermTypeM checkExp tysubsts $ do - let unconstrained_tparams = map (\v -> TypeParamType Unlifted v mempty) unconstrained - (tparams', params'', retdecl'', RetType dims rettype', body'') <- - checkBinding (fname, retdecl', unconstrained_tparams <> tparams, params', body', loc) - - -- Since this is a top-level function, we also resolve overloaded - -- types, using either defaults or complaining about ambiguities. - fixOverloadedTypes $ - typeVars rettype' <> foldMap (typeVars . patternType) params'' - - -- Then replace all inferred types in the body and parameters. - body''' <- normTypeFully body'' - params''' <- mapM normTypeFully params'' - retdecl''' <- traverse updateTypes retdecl'' - rettype'' <- normTypeFully rettype' - - -- Check if the function body can actually be evaluated. - causalityCheck body''' - - -- Check for various problems. - mapM_ (mustBeIrrefutable . fmap toStruct) params'' - localChecks body''' - - let ((body'''', updated_ret), errors) = - Consumption.checkValDef - ( fname, - params''', - body''', - RetType dims rettype'', - retdecl''', - loc - ) - - mapM_ throwError errors - - pure (tparams', params''', retdecl''', updated_ret, body'''') + Right (generalised, tysubsts) -> + runTermTypeM checkExp tysubsts $ do + (tparams', params'', retdecl'', RetType dims rettype', body'') <- + checkBinding (fname, retdecl', generalised <> tparams, params', body', loc) + + -- Since this is a top-level function, we also resolve overloaded + -- types, using either defaults or complaining about ambiguities. + fixOverloadedTypes $ + typeVars rettype' <> foldMap (typeVars . patternType) params'' + + -- Then replace all inferred types in the body and parameters. + body''' <- normTypeFully body'' + params''' <- mapM normTypeFully params'' + retdecl''' <- traverse updateTypes retdecl'' + rettype'' <- normTypeFully rettype' + + -- Check if the function body can actually be evaluated. + causalityCheck body''' + + -- Check for various problems. + mapM_ (mustBeIrrefutable . fmap toStruct) params'' + localChecks body''' + + let ((body'''', updated_ret), errors) = + Consumption.checkValDef + ( fname, + params''', + body''', + RetType dims rettype'', + retdecl''', + loc + ) + + mapM_ throwError errors + + pure (tparams', params''', retdecl''', updated_ret, body'''') diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b6953d810f..d9fc4f82d0 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -49,6 +49,7 @@ import Control.Monad.State import Data.Bifunctor import Data.Bitraversable import Data.Char (isAscii) +import Data.Either (partitionEithers) import Data.List qualified as L import Data.List.NonEmpty qualified as NE import Data.Loc (Loc (NoLoc)) @@ -1125,17 +1126,17 @@ checkExp (Coerce e te NoInfo loc) = do pure $ Coerce e' te' (Info (toStruct st)) loc doDefault :: - S.Set VName -> + [VName] -> VName -> Either [PrimType] (TypeBase () NoUniqueness) -> TermM (TypeBase () NoUniqueness) doDefault tyvars_at_toplevel v (Left pts) | Signed Int32 `elem` pts = do - when (v `S.member` tyvars_at_toplevel) $ + when (v `elem` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to i32." pure $ Scalar $ Prim $ Signed Int32 | FloatType Float64 `elem` pts = do - when (v `S.member` tyvars_at_toplevel) $ + when (v `elem` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to f64." pure $ Scalar $ Prim $ FloatType Float64 | otherwise = @@ -1152,13 +1153,29 @@ doDefault _ _ (Right t) = pure t -- some type variables becoming known, so we have to perform -- substitutions on the RHS of the substitutions afterwards. doDefaults :: - S.Set VName -> + [VName] -> M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -> TermM (M.Map TyVar (TypeBase () NoUniqueness)) doDefaults tyvars_at_toplevel substs = do substs' <- M.traverseWithKey (doDefault tyvars_at_toplevel) substs pure $ M.map (substTyVars (`M.lookup` substs')) substs' +generalise :: + StructType -> [VName] -> Solution -> ([TypeParam], [VName]) +generalise fun_t unconstrained solution = + -- Candidates for let-generalisation are those type variables that + -- are used in fun_t. + let visible = foldMap expandTyVars $ typeVars fun_t + onTyVar v + | v `S.member` visible = Left $ TypeParamType Unlifted v mempty + | otherwise = Right v + in partitionEithers $ map onTyVar unconstrained + where + expandTyVars v = + case M.lookup v solution of + Just (Right t) -> foldMap expandTyVars $ typeVars t + _ -> S.singleton v + checkValDef :: ( VName, Maybe (TypeExp (ExpBase NoInfo VName) VName), @@ -1168,7 +1185,7 @@ checkValDef :: SrcLoc ) -> TypeM - [ ( Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), + [ ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1197,23 +1214,38 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - ranks <- rankAnalysis loc cts tyvars body' - - forM ranks $ \((cts', tyvars'), body'') -> do - solution <- - bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' - debugTraceM 3 $ - unlines - [ "## constraints:", - unlines $ map prettyString cts', - "## tyvars':", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', - "## solution:", - let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, - either (const mempty) (unlines . ("## unconstrained:" :) . map prettyNameString . fst) solution - ] - pure (solution, params', retdecl', body'') + mapM (onRankSolution params' retdecl') =<< rankAnalysis loc cts tyvars body' + where + onRankSolution params' retdecl' ((cts', tyvars'), body'') = do + solution <- + bitraverse pure (onTySolution params' body'') $ solve cts' tyvars' + debugTraceM 3 $ + unlines + [ "## constraints:", + unlines $ map prettyString cts', + "## tyvars':", + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', + "## solution:", + let p (v, t) = prettyNameString v <> " => " <> prettyString t + in either T.unpack (unlines . map p . M.toList . snd) solution, + either (const mempty) (unlines . ("## generalised:" :) . map prettyString . fst) solution + ] + pure (solution, params', retdecl', body'') + + onTySolution params' body' (unconstrained, solution) = do + let fun_t = + foldFunType + (map patternType params') + (RetType [] $ toRes Nonunique (typeOf body')) + (generalised, unconstrained') = + generalise fun_t unconstrained solution + solution' <- doDefaults (map typeParamName generalised) solution + pure + ( generalised, + -- See #1552 for why we resolve unconstrained and + -- un-generalised type variables to (). + M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' + ) checkSingleExp :: ExpBase NoInfo VName -> From d85bd68e0716ac1a445523753e57c53a27c66aed Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 5 Mar 2024 09:19:19 -0800 Subject: [PATCH 164/296] Fix return types of partially applied functions in AM nests. --- src/Futhark/Internalise/FullNormalise.hs | 12 ++++++++---- src/Language/Futhark/Prop.hs | 10 ++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 52dc0fb0f5..9b27a6cf46 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -354,7 +354,7 @@ getOrdering final (AppExp (Match expr cs loc) resT) = do -- a complete separtion of states. transformBody :: (MonadFreshNames m) => Exp -> m Exp transformBody e = do - (e', pre_eval) <- runOrdering . getOrdering True =<< expandAMAnnotations e + (e', pre_eval) <- runOrdering $ getOrdering True e pure $ foldl f e' pre_eval where appRes = case e of @@ -368,7 +368,7 @@ transformBody e = do transformValBind :: (MonadFreshNames m) => ValBind -> m ValBind transformValBind valbind = do - body' <- transformBody $ valBindBody valbind + body' <- transformBody <=< expandAMAnnotations $ valBindBody valbind pure $ valbind {valBindBody = body'} transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind] @@ -508,10 +508,14 @@ expandAMAnnotations e = do let diets = funDiets $ typeOf f withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do inner_f <- setNewType f' $ innerFType (typeOf f') ams - let (_, ret) = unfoldFunType $ typeOf inner_f + let rettype = + case unfoldFunTypeWithRet $ typeOf inner_f of + Nothing -> error "Function type expected." + Just (ptypes, f_ret) -> + foldFunType (drop (length args') ptypes) f_ret pure $ mkApply inner_f (zip3 exts (repeat mempty) args') $ - res {appResType = snd $ unfoldFunType $ typeOf inner_f} + res {appResType = rettype} (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x y' <- expandAMAnnotations y diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 280b531286..d45cc3d294 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -53,6 +53,7 @@ module Language.Futhark.Prop arrayShape, orderZero, unfoldFunType, + unfoldFunTypeWithRet, foldFunType, typeVars, isAccType, @@ -522,6 +523,15 @@ unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) = in (second (const d) t1 : ps, r) unfoldFunType t = ([], toStruct t) +-- | Extract the parameter types and 'RetTypeBase' from a function type. +-- If the type is not an arrow type, returns 'Nothing'. +unfoldFunTypeWithRet :: TypeBase dim as -> Maybe ([TypeBase dim Diet], RetTypeBase dim Uniqueness) +unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 (RetType _ t2@(Scalar Arrow {})))) = do + (ps, r) <- unfoldFunTypeWithRet t2 + pure (second (const d) t1 : ps, r) +unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 r@RetType {})) = Just ([second (const d) t1], r) +unfoldFunTypeWithRet _ = Nothing + -- | The type scheme of a value binding, comprising the type -- parameters and the actual type. valBindTypeScheme :: ValBindBase Info VName -> ([TypeParamBase VName], StructType) From 43c94793c704e97089b2905eb5d5fa90b2c671d8 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 5 Mar 2024 09:27:37 -0800 Subject: [PATCH 165/296] Only return a single solution when doing rank analysis. --- src/Language/Futhark/TypeChecker/Rank.hs | 21 ++++++++++++++++++++- src/Language/Futhark/TypeChecker/Terms.hs | 16 ++-------------- src/Language/Futhark/TypeChecker/Terms2.hs | 13 ++++++------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9672e9aa60..aa04ea1401 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -1,4 +1,8 @@ -module Language.Futhark.TypeChecker.Rank (rankAnalysis) where +module Language.Futhark.TypeChecker.Rank + ( rankAnalysis, + rankAnalysis1, + ) +where import Control.Monad.Reader import Control.Monad.State @@ -13,6 +17,7 @@ import Futhark.Solve.GLPK import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) import Futhark.Solve.LP qualified as LP import Futhark.Util (debugTraceM) +import Futhark.Util.Pretty import Language.Futhark hiding (ScalarType) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints @@ -211,6 +216,20 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs +rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m (([Ct], TyVars), Exp) +rankAnalysis1 loc cs tyVars body = do + solutions <- rankAnalysis loc cs tyVars body + case solutions of + [sol] -> pure sol + sols -> do + let (_, bodies') = unzip sols + typeError loc mempty $ + stack $ + [ "Rank ILP is ambiguous.", + "Choices:" + ] + ++ map pretty bodies' + rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] rankAnalysis loc cs tyVars body = do diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 22ecc9ebc0..8f039c0d9b 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1636,20 +1636,8 @@ checkFunDef :: ResRetType, Exp ) -checkFunDef (fname, retdecl, tparams, params, body, loc) = do - solutions <- - Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) - case solutions of - [(maybe_tysubsts, params', retdecl', body')] -> - doChecks (maybe_tysubsts, params', retdecl', body') - ls -> do - let (_, _, _, bodies') = unzip4 ls - typeError loc mempty $ - stack $ - [ "Rank ILP is ambiguous.", - "Choices:" - ] - ++ map pretty bodies' +checkFunDef (fname, retdecl, tparams, params, body, loc) = + doChecks =<< Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) where -- TODO: Print out the possibilities. (And also potentially eliminate --- some of the possibilities to disambiguate). diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d9fc4f82d0..da553c4089 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1185,12 +1185,11 @@ checkValDef :: SrcLoc ) -> TypeM - [ ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), - [Pat ParamType], - Maybe (TypeExp Exp VName), - Exp - ) - ] + ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), + [Pat ParamType], + Maybe (TypeExp Exp VName), + Exp + ) checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do (params', body', retdecl') <- bindParams tparams params $ \params' -> do @@ -1214,7 +1213,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - mapM (onRankSolution params' retdecl') =<< rankAnalysis loc cts tyvars body' + onRankSolution params' retdecl' =<< rankAnalysis1 loc cts tyvars body' where onRankSolution params' retdecl' ((cts', tyvars'), body'') = do solution <- From 155f2bbb4f0a195685e63ccb78951bbad433910b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 14:31:27 +0100 Subject: [PATCH 166/296] Strangle some warnings. --- src/Language/Futhark/TypeChecker/Terms.hs | 28 ++++++++--------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 8f039c0d9b..26e0705ca8 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -22,7 +22,7 @@ import Data.Bifunctor import Data.Bitraversable import Data.Char (isAscii) import Data.Either -import Data.List (delete, find, genericLength, partition, unzip4) +import Data.List (delete, find, genericLength, partition) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe @@ -465,7 +465,7 @@ checkExp (Coerce e te _ loc) = do t <- expTypeFully e' t' <- matchDims (const . const pure) t te_t pure $ Coerce e' te' (Info t') loc -checkExp e@(AppExp (Apply fe args loc) _) = do +checkExp (AppExp (Apply fe args loc) _) = do fe' <- checkExp fe let ams = fmap (snd . unInfo . fst) args args' <- mapM (checkExp . snd) args @@ -727,7 +727,7 @@ checkExp (OpSection op (Info op_t) loc) = do checkExp (OpSectionLeft op (Info op_t) e (Info (_, _, _, am), _) _ loc) = do ftype <- lookupVar loc op op_t e' <- checkExp e - (t1, rt, argext, retext, am) <- checkApply loc (Just op, 0) ftype e' am + (t1, rt, argext, retext, am') <- checkApply loc (Just op, 0) ftype e' am case (ftype, rt) of (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 (RetType ds rt2))) -> pure $ @@ -735,8 +735,8 @@ checkExp (OpSectionLeft op (Info op_t) e (Info (_, _, _, am), _) _ loc) = do op (Info ftype) e' - (Info (m1, toParam d1 t1, argext, am), Info (m2, toParam d2 t2)) - (Info $ RetType ds $ arrayOfWithAliases (uniqueness rt2) (autoFrame am) rt2, Info retext) + (Info (m1, toParam d1 t1, argext, am'), Info (m2, toParam d2 t2)) + (Info $ RetType ds $ arrayOfWithAliases (uniqueness rt2) (autoFrame am') rt2, Info retext) loc _ -> typeError loc mempty $ @@ -746,7 +746,7 @@ checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do - (t2', arrow', argext, _, am) <- + (t2', arrow', argext, _, am') <- checkApply loc (Just op, 1) @@ -760,8 +760,8 @@ checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do op (Info ftype) e' - (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext, am)) - (Info $ RetType dims2' $ arrayOfWithAliases (uniqueness ret') (autoFrame am) ret') + (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext, am')) + (Info $ RetType dims2' $ arrayOfWithAliases (uniqueness ret') (autoFrame am') ret') loc _ -> error $ "OpSectionRight: impossible type\n" <> prettyString arrow' _ -> @@ -922,14 +922,6 @@ dimUses = flip execState mempty . traverseDims f where fv = freeInExp e `freeWithout` bound --- | Try to find out how many dimensions of the argument we are --- mapping. Returns the shape mapped and the remaining type. -stripToMatch :: StructType -> StructType -> (Shape Size, StructType) -stripToMatch paramt argt | toStructural paramt == toStructural argt = (mempty, argt) -stripToMatch paramt (Array _ (Shape (d : ds)) argt) = - first (Shape [d] <>) $ stripToMatch paramt $ arrayOf (Shape ds) (Scalar argt) -stripToMatch _ argt = (mempty, argt) - splitArrayAt :: Int -> StructType -> (Shape Size, StructType) splitArrayAt x t = (Shape $ take x $ shapeDims $ arrayShape t, stripArray x t) @@ -1018,14 +1010,14 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d in pure (Nothing, applySubst parsubst $ toStruct tp2') _ -> pure (Nothing, toStruct tp2') - let am = + let am' = AutoMap { autoRep = am_rep_shape, autoMap = am_map_shape, autoFrame = am_frame_shape } - pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am) + pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am') where distribute :: TypeBase dim u -> TypeBase dim u distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = From bc063036c1a51d949f29ec9d6a65063b31dd6ab1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 14:59:19 +0100 Subject: [PATCH 167/296] Fix typo. --- src/Language/Futhark/TypeChecker/Terms.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 26e0705ca8..2c70460821 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1017,7 +1017,7 @@ checkApply loc fn@(fname, _) ft@(Scalar (Arrow _ pname _ tp1 tp2)) argexp am = d autoFrame = am_frame_shape } - pure (tp1, distribute (arrayOf (autoMap am) tp2''), argext, ext, am') + pure (tp1, distribute (arrayOf (autoMap am') tp2''), argext, ext, am') where distribute :: TypeBase dim u -> TypeBase dim u distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = From 07b507d15ee73c960adc605109084089aede8871 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 15:06:21 +0100 Subject: [PATCH 168/296] Remove unneeded things. --- futhark.cabal | 1 - src/Futhark/IR/Syntax/Core.hs | 6 ------ 2 files changed, 7 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index f4fb4062bd..b6f570ff80 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -495,7 +495,6 @@ library , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 -- remove me later - , process , glpk-hs , silently diff --git a/src/Futhark/IR/Syntax/Core.hs b/src/Futhark/IR/Syntax/Core.hs index 982fadcdec..227c25b23b 100644 --- a/src/Futhark/IR/Syntax/Core.hs +++ b/src/Futhark/IR/Syntax/Core.hs @@ -15,7 +15,6 @@ module Futhark.IR.Syntax.Core ShapeBase (..), Shape, stripDims, - takeDims, Ext (..), ExtSize, ExtShape, @@ -129,11 +128,6 @@ instance Monoid (ShapeBase d) where stripDims :: Int -> ShapeBase d -> ShapeBase d stripDims n (Shape dims) = Shape $ drop n dims --- | @takeDims n shape@ takes the outer @n@ dimensions from --- @shape@. If @shape@ has m <= n dimensions, it returns $shape$. -takeDims :: Int -> ShapeBase d -> ShapeBase d -takeDims n (Shape dims) = Shape $ take n dims - -- | The size of an array as a list of subexpressions. If a variable, -- that variable must be in scope where this array is used. type Shape = ShapeBase SubExp From a49537458d8f36ff0f6704a17f331d8f90e65219 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 15:15:45 +0100 Subject: [PATCH 169/296] Off-by-truth. --- src/Language/Futhark/TypeChecker/Consumption.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index 5ebe2996df..8c92e54d20 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -489,7 +489,7 @@ consumeAsNeeded loc pt t = checkArg :: [(Exp, TypeAliases)] -> ParamType -> AutoMap -> Exp -> CheckM (Exp, TypeAliases) checkArg prev p_t am e = do ((e', e_als), e_cons) <- - contain $ if autoRep am == mempty then noAliases e else checkExp e + contain $ if autoRep am /= mempty then noAliases e else checkExp e consumed e_cons let e_t = typeOf e' when (e_cons /= mempty && not (orderZero e_t)) $ From c11e72e88d25f0994e0ee2eafb9b1792f119c142 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 15:46:09 +0100 Subject: [PATCH 170/296] Add bindingParam. --- .../Futhark/TypeChecker/Terms/Loop.hs | 40 +++++++++---------- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 12 +++++- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index 51d1c8ceba..e1afdfdb4c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -223,16 +223,15 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do uboundexp' <- checkExp uboundexp it <- expType uboundexp' let i' = i {identType = Info it} - bindingIdent i' . bindingPat [] mergepat merge_t $ - \mergepat' -> incLevel $ do - loopbody' <- checkExp loopbody - (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' - pure - ( sparams, - mergepat'', - For i' uboundexp', - loopbody' - ) + bindingIdent i' . bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do + loopbody' <- checkExp loopbody + (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' + pure + ( sparams, + mergepat'', + For i' uboundexp', + loopbody' + ) ForIn xpat e -> do (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1 e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e @@ -241,22 +240,21 @@ checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do _ | Just t' <- peelArray 1 t -> bindingPat [] xpat t' $ \xpat' -> - bindingPat [] mergepat merge_t $ - \mergepat' -> incLevel $ do - loopbody' <- checkExp loopbody - (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' - pure - ( sparams, - mergepat'', - ForIn (fmap toStruct xpat') e', - loopbody' - ) + bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do + loopbody' <- checkExp loopbody + (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' + pure + ( sparams, + mergepat'', + ForIn (fmap toStruct xpat') e', + loopbody' + ) | otherwise -> typeError (srclocOf e) mempty $ "Iteratee of a for-in loop must be an array, but expression has type" <+> pretty t While cond -> - bindingPat [] mergepat merge_t $ \mergepat' -> + bindingParam mergepat merge_t $ \mergepat' -> incLevel $ do cond' <- checkExp cond diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 4485c20f4c..b86dd63616 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -2,6 +2,7 @@ module Language.Futhark.TypeChecker.Terms.Pat ( binding, bindingParams, + bindingParam, bindingPat, bindingIdent, bindingSizes, @@ -15,7 +16,6 @@ import Data.List (find, isPrefixOf, sort) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S -import Futhark.Util import Futhark.Util.Pretty hiding (group, space) import Language.Futhark import Language.Futhark.TypeChecker.Monad hiding (BoundV) @@ -204,6 +204,16 @@ checkPat sizes p t m = do [] -> m p' +-- | Check and bind a single parameter. +bindingParam :: + Pat ParamType -> + StructType -> + (Pat ParamType -> TermTypeM a) -> + TermTypeM a +bindingParam p t m = do + checkPat mempty p (Ascribed t) $ \p' -> + binding (patIdents (fmap toStruct p')) $ m p' + -- | Check and bind a @let@-pattern. bindingPat :: [SizeBinder VName] -> From 47d24deb79809c8d2d7be7e02bc199764e825a0d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 16:33:27 +0100 Subject: [PATCH 171/296] Allow touching TypeExps here. --- src/Language/Futhark/Traversals.hs | 42 +++++++++++++++++------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index 94b440b2ff..cd889944cd 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -27,6 +27,7 @@ module Language.Futhark.Traversals where import Data.Bifunctor +import Data.Bitraversable import Data.List.NonEmpty qualified as NE import Language.Futhark.Syntax @@ -337,31 +338,36 @@ instance ASTMappable (IdentBase Info VName StructType) where astMap tv (Ident name (Info t) loc) = Ident name <$> (Info <$> mapOnStructType tv t) <*> pure loc -traversePat :: (Monad m) => (t1 -> m t2) -> PatBase Info VName t1 -> m (PatBase Info VName t2) -traversePat f (Id name (Info t) loc) = +traversePat :: + (Monad m) => + (t1 -> m t2) -> + (ExpBase Info VName -> m (ExpBase Info VName)) -> + PatBase Info VName t1 -> + m (PatBase Info VName t2) +traversePat f _ (Id name (Info t) loc) = Id name <$> (Info <$> f t) <*> pure loc -traversePat f (TuplePat pats loc) = - TuplePat <$> mapM (traversePat f) pats <*> pure loc -traversePat f (RecordPat fields loc) = - RecordPat <$> mapM (traverse $ traversePat f) fields <*> pure loc -traversePat f (PatParens pat loc) = - PatParens <$> traversePat f pat <*> pure loc -traversePat f (PatAscription pat t loc) = - PatAscription <$> traversePat f pat <*> pure t <*> pure loc -traversePat f (Wildcard (Info t) loc) = +traversePat f g (TuplePat pats loc) = + TuplePat <$> mapM (traversePat f g) pats <*> pure loc +traversePat f g (RecordPat fields loc) = + RecordPat <$> mapM (traverse $ traversePat f g) fields <*> pure loc +traversePat f g (PatParens pat loc) = + PatParens <$> traversePat f g pat <*> pure loc +traversePat f g (PatAscription pat t loc) = + PatAscription <$> traversePat f g pat <*> bitraverse g pure t <*> pure loc +traversePat f _ (Wildcard (Info t) loc) = Wildcard <$> (Info <$> f t) <*> pure loc -traversePat f (PatLit v (Info t) loc) = +traversePat f _ (PatLit v (Info t) loc) = PatLit v <$> (Info <$> f t) <*> pure loc -traversePat f (PatConstr n (Info t) ps loc) = - PatConstr n <$> (Info <$> f t) <*> mapM (traversePat f) ps <*> pure loc -traversePat f (PatAttr attr p loc) = - PatAttr attr <$> traversePat f p <*> pure loc +traversePat f g (PatConstr n (Info t) ps loc) = + PatConstr n <$> (Info <$> f t) <*> mapM (traversePat f g) ps <*> pure loc +traversePat f g (PatAttr attr p loc) = + PatAttr attr <$> traversePat f g p <*> pure loc instance ASTMappable (PatBase Info VName StructType) where - astMap tv = traversePat $ mapOnStructType tv + astMap tv = traversePat (mapOnStructType tv) (mapOnExp tv) instance ASTMappable (PatBase Info VName ParamType) where - astMap tv = traversePat $ mapOnParamType tv + astMap tv = traversePat (mapOnParamType tv) (mapOnExp tv) instance ASTMappable (FieldBase Info VName) where astMap tv (RecordFieldExplicit name e loc) = From b34f3ab8e18eeb961c09acab4d325cf81fcf8cde Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 6 Mar 2024 16:34:57 +0100 Subject: [PATCH 172/296] Also perform AM on expressions in params. We are certainly missing params in expressions, such as in Lambda. --- src/Language/Futhark/TypeChecker/Rank.hs | 39 ++++++++++++---------- src/Language/Futhark/TypeChecker/Terms2.hs | 6 ++-- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index aa04ea1401..9767c35902 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -123,7 +123,7 @@ addCt (CtAM r m f) = do b_m <- binVar m b_max <- VName "c_max" <$> incCounter tr <- VName ("T_" <> baseName r) <$> incCounter - addConstraints $ [bin b_max, var b_max ~<=~ var tr] + addConstraints [bin b_max, var b_max ~<=~ var tr] addConstraints $ oneIsZero (b_r, r) (b_m, m) addConstraints $ LP.max b_max (constant 0) (rank r ~-~ rank f) (var tr) addObj m @@ -216,13 +216,13 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m (([Ct], TyVars), Exp) -rankAnalysis1 loc cs tyVars body = do - solutions <- rankAnalysis loc cs tyVars body +rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), [Pat ParamType], Exp) +rankAnalysis1 loc cs tyVars params body = do + solutions <- rankAnalysis loc cs tyVars params body case solutions of [sol] -> pure sol sols -> do - let (_, bodies') = unzip sols + let (_, _, bodies') = unzip3 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -230,13 +230,14 @@ rankAnalysis1 loc cs tyVars body = do ] ++ map pretty bodies' -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> Exp -> m [(([Ct], TyVars), Exp)] -rankAnalysis _ [] tyVars body = pure [(([], tyVars), body)] -rankAnalysis loc cs tyVars body = do +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), [Pat ParamType], Exp)] +rankAnalysis _ [] tyVars params body = pure [(([], tyVars), params, body)] +rankAnalysis loc cs tyVars params body = do rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps let bodys = map (`updAM` body) rank_maps - pure $ zip cts_tyvars' bodys + params' = map ((`map` params) . updAMPat) rank_maps + pure $ zip3 cts_tyvars' params' bodys where cs' = foldMap (splitFuncs . distribute) cs splitFuncs @@ -325,7 +326,7 @@ rankToShape x = do addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do rs <- asks envRanks - if (fromMaybe 0 (rs M.!? t) == 0) + if fromMaybe 0 (rs M.!? t) == 0 then do old_tyvars <- asks envTyVars case old_tyvars M.!? t of @@ -334,7 +335,7 @@ addRankInfo t = do -- -- is anyPrimType right here? -- modify $ -- \s -> s {substTyVars = M.insert t (lvl, TyVarPrim anyPrimType) $ substTyVars s} - _ -> do + _ -> pure () else do new_vars <- gets substNewVars @@ -381,13 +382,7 @@ updAM rank_map e = case e of AppExp (Apply f args loc) res -> let f' = updAM rank_map f - args' = - fmap - ( bimap - (fmap $ bimap id upd) - (updAM rank_map) - ) - args + args' = fmap (bimap (fmap $ second upd) (updAM rank_map)) args in AppExp (Apply f' args' loc) res AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res @@ -403,3 +398,11 @@ updAM rank_map e = identityMapper { mapOnExp = pure . updAM rank_map } + +updAMPat :: M.Map VName Int -> Pat ParamType -> Pat ParamType +updAMPat rank_map p = runIdentity $ astMap m p + where + m = + identityMapper + { mapOnExp = pure . updAM rank_map + } diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index da553c4089..376576a13b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1213,9 +1213,9 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars ] - onRankSolution params' retdecl' =<< rankAnalysis1 loc cts tyvars body' + onRankSolution retdecl' =<< rankAnalysis1 loc cts tyvars params' body' where - onRankSolution params' retdecl' ((cts', tyvars'), body'') = do + onRankSolution retdecl' ((cts', tyvars'), params', body'') = do solution <- bitraverse pure (onTySolution params' body'') $ solve cts' tyvars' debugTraceM 3 $ @@ -1267,7 +1267,7 @@ checkSizeExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars - (cts_tyvars', es') <- unzip <$> rankAnalysis (srclocOf e) cts tyvars e' + (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars [] e' solutions <- forM cts_tyvars' $ From 6305172d9063a9bbde839cd279ecefc57b89b4ad Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Wed, 6 Mar 2024 08:28:59 -0800 Subject: [PATCH 173/296] Remove lingering `PuLP` stuff. --- shell.nix | 3 --- src/Futhark/Solve/LP.hs | 35 ----------------------------------- 2 files changed, 38 deletions(-) diff --git a/shell.nix b/shell.nix index d5199b0c02..a5ddb63ab0 100644 --- a/shell.nix +++ b/shell.nix @@ -52,9 +52,6 @@ pkgs.stdenv.mkDerivation { python3Packages.sphinx python3Packages.sphinxcontrib-bibtex imagemagick # needed for literate tests - # remove (needed for PuLP) - python - cbc glpk ] ++ lib.optionals (stdenv.isLinux) diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index a2224617ea..f1b7d18939 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -28,7 +28,6 @@ module Futhark.Solve.LP (~<=~), (~>=~), rowEchelonLPE, - linearProgToPulp, ) where @@ -168,40 +167,6 @@ instance (Ord v) => Vars (LinearProg v a) v where vars (objective lp) <> foldMap vars (constraints lp) --- For debugging -linearProgToPulp :: (Unbox a, IsName v, Ord v, Pretty a, Eq a, Num a) => LinearProg v a -> String -linearProgToPulp prog = - map rm_subscript $ - unlines - [ "from pulp import *", - "prob = LpProblem('', " <> lptype <> ")", - unlines vars, - unlines $ map (("prob += " <>) . prettyString) $ constraints prog, - "status = prob.solve()", - "print(f'status: {status}')", - unlines res - ] - where - lptype = - case optType prog of - Maximize -> "LpMaximize" - Minimize -> "LpMinimize" - prog_vars = Map.elems $ snd $ linearProgToLP prog - vars = - map - ( \v -> - show (prettyName v) - <> " = " - <> "LpVariable(" - <> "'" - <> show (prettyName v) - <> "_'" - <> ", lowBound = 0, cat = 'Integer')" - ) - prog_vars - res = map (\v -> "print(f'" <> show (prettyName v) <> ": {value(" <> show (prettyName v) <> ")}')") prog_vars - rm_subscript x = fromMaybe x $ lookup x $ zip "₀₁₂₃₄₅₆₇₈₉" "0123456789" - bigM :: (Num a) => a bigM = 2 ^ 10 From c8a4348d5448fa6e19e77044b086477a31b8097a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 12:49:10 +0100 Subject: [PATCH 174/296] Start adding location info. --- .../Futhark/TypeChecker/Constraints.hs | 88 +++++++++++-------- src/Language/Futhark/TypeChecker/Rank.hs | 10 +-- src/Language/Futhark/TypeChecker/Terms.hs | 6 +- src/Language/Futhark/TypeChecker/Terms2.hs | 23 ++--- 4 files changed, 71 insertions(+), 56 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ebac367dfb..e78bf993b2 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,5 +1,6 @@ module Language.Futhark.TypeChecker.Constraints - ( SVar, + ( Reason (..), + SVar, SComp (..), Type, toType, @@ -16,14 +17,22 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor +import Data.Loc import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S -import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Monad (TypeError (..)) import Language.Futhark.TypeChecker.Types (substTyVars) +-- | The reason for a type constraint. Used to generate type error +-- messages. +newtype Reason = Reason + { reasonLoc :: Loc + } + deriving (Eq, Ord, Show) + type SVar = VName -- | A shape component. `SDim` is a single dimension of unspecified @@ -62,23 +71,31 @@ instance Pretty Ct where type Constraints = [Ct] --- | Information about a type variable. +-- | Information about a type variable. Every type variable is +-- associated with a location, which is the original syntax element +-- that it is the type of. data TyVarInfo = -- | Can be substituted with anything. - TyVarFree + TyVarFree Loc | -- | Can only be substituted with these primitive types. - TyVarPrim [PrimType] + TyVarPrim Loc [PrimType] | -- | Must be a record with these fields. - TyVarRecord (M.Map Name Type) + TyVarRecord Loc (M.Map Name Type) | -- | Must be a sum type with these fields. - TyVarSum (M.Map Name [Type]) + TyVarSum Loc (M.Map Name [Type]) deriving (Show, Eq) instance Pretty TyVarInfo where - pretty TyVarFree = "free" - pretty (TyVarPrim pts) = "∈" <+> pretty pts - pretty (TyVarRecord fs) = pretty $ Scalar $ Record fs - pretty (TyVarSum cs) = pretty $ Scalar $ Sum cs + pretty (TyVarFree _) = "free" + pretty (TyVarPrim _ pts) = "∈" <+> pretty pts + pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs + pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs + +instance Located TyVarInfo where + locOf (TyVarFree loc) = loc + locOf (TyVarPrim loc _) = loc + locOf (TyVarRecord loc _) = loc + locOf (TyVarSum loc _) = loc type TyVar = VName @@ -134,20 +151,20 @@ solution s = mkSubst (TyVarLink v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (TyVarUnsol _ (TyVarPrim pts)) = Just $ Left pts + mkSubst (TyVarUnsol _ (TyVarPrim _ pts)) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, TyVarUnsol _ TyVarFree) = Just v + unconstrained (v, TyVarUnsol _ (TyVarFree _)) = Just v unconstrained _ = Nothing -newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except T.Text) a} - deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError T.Text) +newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} + deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) occursCheck :: VName -> Type -> SolveM () occursCheck v tp = do vars <- gets solverTyVars let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . throwError . docText $ + when (v `S.member` typeVars tp') . throwError . TypeError mempty mempty $ "Occurs check: cannot instantiate" <+> prettyName v <+> "with" @@ -166,7 +183,7 @@ linkTyVar v t = do modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} tyvars' <- case (M.lookup v tyvars, M.lookup t tyvars) of - (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree)) -> + (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl (TyVarFree _))) -> pure $ M.insert t (TyVarUnsol lvl info) tyvars -- TODO: handle more cases. _ -> pure tyvars @@ -211,7 +228,7 @@ solveCt ct = CtEq t1 t2 -> solveCt' (t1, t2) CtAM {} -> pure () -- Good vibes only. where - bad = throwError $ "Unsolvable: " <> prettyText ct + bad = throwError $ TypeError mempty mempty $ "Unsolvable:" <+> pretty ct solveCt' (t1, t2) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of @@ -249,39 +266,36 @@ solveCt ct = Just eqs -> mapM_ solveCt' eqs solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () -solveTyVar (tv, (_, TyVarFree {})) = pure () -solveTyVar (tv, (_, TyVarPrim pts)) = do +solveTyVar (_, (_, TyVarFree {})) = pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do t <- lookupTyVar tv case t of Nothing -> pure () Just t' | t' `elem` map (Scalar . Prim) pts -> pure () | otherwise -> - throwError $ - "Type variable " - <> prettyNameText tv - <> " must be one of\n" - <> prettyText pts - <> "\nbut inferred to be\n" - <> prettyText t' -solveTyVar (tv, (_, TyVarRecord fs1)) = do + throwError . TypeError loc mempty $ + "Type must be one of" + indent 2 (pretty pts) + "but inferred to be" + indent 2 (pretty t') +solveTyVar (tv, (_, TyVarRecord loc fs1)) = do tv_t <- lookupTyVar tv case tv_t of Nothing -> pure () Just (Scalar (Record fs2)) | all (`M.member` fs2) (M.keys fs1) -> - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(k, (t1, t2)) -> + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> solveCt $ CtEq t1 t2 Just tv_t' -> throwError $ - "Type variable " - <> prettyNameText tv - <> " must be record with fields\n" - <> prettyText (Scalar (Record fs1)) - <> " but inferred to be\n" - <> prettyText tv_t' - -solve :: Constraints -> TyVars -> Either T.Text ([VName], Solution) + TypeError loc mempty $ + "Type must be record with fields" + indent 2 (pretty (Scalar (Record fs1))) + "but inferred to be" + indent 2 (pretty tv_t') + +solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) solve constraints tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9767c35902..e42ea67d52 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -130,12 +130,12 @@ addCt (CtAM r m f) = do addObj tr addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () -addTyVarInfo _ (_, TyVarFree) = pure () -addTyVarInfo tv (_, TyVarPrim _) = +addTyVarInfo _ (_, TyVarFree _) = pure () +addTyVarInfo tv (_, TyVarPrim {}) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarRecord _) = +addTyVarInfo tv (_, TyVarRecord {}) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarSum _) = +addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: [Ct] -> TyVars -> LinearProg @@ -346,7 +346,7 @@ addRankInfo t = do old_tyvars <- asks envTyVars let info = fromJust $ old_tyvars M.!? t modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree mempty) $ substTyVars s} class SubstRanks a where substRanks :: (MonadTypeChecker m) => a -> SubstT m a diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 2c70460821..c54c0b5805 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1039,7 +1039,7 @@ checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) checkOneExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of - Left err -> typeError e' mempty $ pretty err + Left err -> throwError err Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' @@ -1057,7 +1057,7 @@ checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of - Left err -> typeError e' mempty $ pretty err + Left err -> throwError err Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' when (hasBinding e'') $ @@ -1636,7 +1636,7 @@ checkFunDef (fname, retdecl, tparams, params, body, loc) = doChecks (maybe_tysubsts, params', retdecl', body') = case maybe_tysubsts of - Left err -> typeError loc mempty $ pretty err + Left err -> throwError err Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do (tparams', params'', retdecl'', RetType dims rettype', body'') <- diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 376576a13b..d73c6758d9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -220,8 +220,8 @@ incCounter = do tyVarType :: u -> TyVar -> TypeBase dim u tyVarType u v = Scalar $ TypeVar u (qualName v) [] -newTyVarWith :: (Located loc) => loc -> Name -> TyVarInfo -> TermM TyVar -newTyVarWith _loc desc info = do +newTyVarWith :: Name -> TyVarInfo -> TermM TyVar +newTyVarWith desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i lvl <- curLevel @@ -229,24 +229,25 @@ newTyVarWith _loc desc info = do pure v newTyVar :: (Located loc) => loc -> Name -> TermM TyVar -newTyVar loc desc = newTyVarWith loc desc TyVarFree +newTyVar loc desc = newTyVarWith desc $ TyVarFree $ locOf loc newType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) newType loc desc u = tyVarType u <$> newTyVar loc desc newTypeWithField :: SrcLoc -> Name -> Name -> Type -> TermM Type newTypeWithField loc desc k t = - tyVarType NoUniqueness <$> newTyVarWith loc desc (TyVarRecord $ M.singleton k t) + tyVarType NoUniqueness + <$> newTyVarWith desc (TyVarRecord (locOf loc) $ M.singleton k t) newTypeWithConstr :: SrcLoc -> Name -> u -> Name -> [TypeBase SComp u] -> TermM (TypeBase d u) newTypeWithConstr loc desc u k ts = - tyVarType u <$> newTyVarWith loc desc (TyVarSum $ M.singleton k ts') + tyVarType u <$> newTyVarWith desc (TyVarSum (locOf loc) $ M.singleton k ts') where ts' = map (`setUniqueness` NoUniqueness) ts newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniqueness) newTypeOverloaded loc name pts = - tyVarType NoUniqueness <$> newTyVarWith loc name (TyVarPrim pts) + tyVarType NoUniqueness <$> newTyVarWith name (TyVarPrim (locOf loc) pts) newSVar :: (Located loc) => loc -> Name -> TermM SVar newSVar _loc desc = do @@ -1185,7 +1186,7 @@ checkValDef :: SrcLoc ) -> TypeM - ( Either T.Text ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), + ( Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), [Pat ParamType], Maybe (TypeExp Exp VName), Exp @@ -1226,7 +1227,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", let p (v, t) = prettyNameString v <> " => " <> prettyString t - in either T.unpack (unlines . map p . M.toList . snd) solution, + in either (docString . prettyTypeError) (unlines . map p . M.toList . snd) solution, either (const mempty) (unlines . ("## generalised:" :) . map prettyString . fst) solution ] pure (solution, params', retdecl', body'') @@ -1248,7 +1249,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do checkSingleExp :: ExpBase NoInfo VName -> - TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints @@ -1261,7 +1262,7 @@ checkSingleExp e = runTermM $ do -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> - TypeM (Either T.Text ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints @@ -1276,4 +1277,4 @@ checkSizeExp e = runTermM $ do case (solutions, es') of ([solution], [e'']) -> pure (solution, e'') - _ -> pure (Left "Ambiguous size expression", e') + _ -> pure (Left $ TypeError (locOf e) mempty "Ambiguous size expression", e') From 56c99a89c8d32ccec5dc229519ed58b9444354ac Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 7 Mar 2024 05:23:31 -0800 Subject: [PATCH 175/296] More location info. --- src/Language/Futhark/TypeChecker/Rank.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index e42ea67d52..85e1766e13 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -344,9 +344,9 @@ addRankInfo t = do new_var = do t' <- newTyVar t old_tyvars <- asks envTyVars - let info = fromJust $ old_tyvars M.!? t - modify $ \s -> s {substTyVars = M.insert t' info $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (fst info, TyVarFree mempty) $ substTyVars s} + let (level, tvinfo) = fromJust $ old_tyvars M.!? t + modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree $ locOf tvinfo) $ substTyVars s} class SubstRanks a where substRanks :: (MonadTypeChecker m) => a -> SubstT m a From 05fe9a0a5673631e672ddc8d380b5aaa51159910 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 14:30:56 +0100 Subject: [PATCH 176/296] Also put locations in constraints. --- .../Futhark/TypeChecker/Constraints.hs | 77 +++++++++++------ src/Language/Futhark/TypeChecker/Rank.hs | 12 +-- src/Language/Futhark/TypeChecker/Terms2.hs | 86 +++++++++---------- 3 files changed, 99 insertions(+), 76 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index e78bf993b2..eddc37691c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -33,6 +33,9 @@ newtype Reason = Reason } deriving (Eq, Ord, Show) +instance Located Reason where + locOf = reasonLoc + type SVar = VName -- | A shape component. `SDim` is a single dimension of unspecified @@ -61,13 +64,20 @@ toType :: TypeBase Size u -> TypeBase SComp u toType = first (const SDim) data Ct - = CtEq Type Type - | CtAM SVar SVar (Shape SComp) + = CtEq Reason Type Type + | CtAM Reason SVar SVar (Shape SComp) deriving (Show) +ctReason :: Ct -> Reason +ctReason (CtEq r _ _) = r +ctReason (CtAM r _ _ _) = r + +instance Located Ct where + locOf = locOf . ctReason + instance Pretty Ct where - pretty (CtEq t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtAM r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" + pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2 + pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" type Constraints = [Ct] @@ -160,25 +170,25 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) -occursCheck :: VName -> Type -> SolveM () -occursCheck v tp = do +occursCheck :: Reason -> VName -> Type -> SolveM () +occursCheck reason v tp = do vars <- gets solverTyVars let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . throwError . TypeError mempty mempty $ + when (v `S.member` typeVars tp') . throwError . TypeError (locOf reason) mempty $ "Occurs check: cannot instantiate" <+> prettyName v <+> "with" <+> pretty tp <> "." -subTyVar :: VName -> Int -> Type -> SolveM () -subTyVar v lvl t = do - occursCheck v t +subTyVar :: Reason -> VName -> Int -> Type -> SolveM () +subTyVar reason v lvl t = do + occursCheck reason v t modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} -linkTyVar :: VName -> VName -> SolveM () -linkTyVar v t = do - occursCheck v $ Scalar $ TypeVar NoUniqueness (qualName t) [] +linkTyVar :: Reason -> VName -> VName -> SolveM () +linkTyVar reason v t = do + occursCheck reason v $ Scalar $ TypeVar NoUniqueness (qualName t) [] tyvars <- gets solverTyVars modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} tyvars' <- @@ -222,13 +232,18 @@ unify t1 t2 Just [(t1', t2')] unify _ _ = Nothing -solveCt :: Ct -> SolveM () -solveCt ct = - case ct of - CtEq t1 t2 -> solveCt' (t1, t2) - CtAM {} -> pure () -- Good vibes only. +solveEq :: Reason -> Type -> Type -> SolveM () +solveEq reason orig_t1 orig_t2 = do + solveCt' (orig_t1, orig_t2) where - bad = throwError $ TypeError mempty mempty $ "Unsolvable:" <+> pretty ct + cannotUnify = do + tyvars <- gets solverTyVars + throwError . TypeError (locOf reason) mempty $ + "Cannot unify" + indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) + "with" + indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) + solveCt' (t1, t2) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of @@ -249,22 +264,28 @@ solveCt ct = | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (Nothing, Nothing) -> bad - (Just lvl, Nothing) -> subTyVar v1 lvl t2' - (Nothing, Just lvl) -> subTyVar v2 lvl t1' + (Nothing, Nothing) -> cannotUnify + (Just lvl, Nothing) -> subTyVar reason v1 lvl t2' + (Nothing, Just lvl) -> subTyVar reason v2 lvl t1' (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> linkTyVar v1 v2 - | otherwise -> linkTyVar v2 v1 + | lvl1 <= lvl2 -> linkTyVar reason v1 v2 + | otherwise -> linkTyVar reason v2 v1 (Scalar (TypeVar _ (QualName [] v1) []), t2') | Just lvl <- flexible v1 -> - subTyVar v1 lvl t2' + subTyVar reason v1 lvl t2' (t1', Scalar (TypeVar _ (QualName [] v2) [])) | Just lvl <- flexible v2 -> - subTyVar v2 lvl t1' + subTyVar reason v2 lvl t1' (t1', t2') -> case unify t1' t2' of - Nothing -> bad + Nothing -> cannotUnify Just eqs -> mapM_ solveCt' eqs +solveCt :: Ct -> SolveM () +solveCt ct = + case ct of + CtEq reason t1 t2 -> solveEq reason t1 t2 + CtAM {} -> pure () -- Good vibes only. + solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () solveTyVar (_, (_, TyVarFree {})) = pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do @@ -286,7 +307,7 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do Just (Scalar (Record fs2)) | all (`M.member` fs2) (M.keys fs1) -> forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> - solveCt $ CtEq t1 t2 + solveCt $ CtEq (Reason loc) t1 t2 Just tv_t' -> throwError $ TypeError loc mempty $ diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 85e1766e13..9af5623d5e 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -72,7 +72,7 @@ instance Distribute (TypeBase dim u) where distribute t = t instance Distribute Ct where - distribute (CtEq t1 t2) = distribute t1 `CtEq` distribute t2 + distribute (CtEq r t1 t2) = CtEq r (distribute t1) (distribute t2) distribute c = c data RankState = RankState @@ -117,8 +117,8 @@ addObj sv = modify $ \s -> s {rankObj = rankObj s ~+~ var sv} addCt :: Ct -> RankM () -addCt (CtEq t1 t2) = addConstraint $ rank t1 ~==~ rank t2 -addCt (CtAM r m f) = do +addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2 +addCt (CtAM _ r m f) = do b_r <- binVar r b_m <- binVar m b_max <- VName "c_max" <$> incCounter @@ -242,10 +242,11 @@ rankAnalysis loc cs tyVars params body = do cs' = foldMap (splitFuncs . distribute) cs splitFuncs ( CtEq + reason (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) ) = - splitFuncs (CtEq t1a t2a) ++ splitFuncs (CtEq t1r' t2r') + splitFuncs (CtEq reason t1a t2a) ++ splitFuncs (CtEq reason t1r' t2r') where t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness @@ -312,6 +313,7 @@ newTyVar t = do substNewCts = substNewCts s ++ [ CtEq + (Reason mempty) -- FIXME (Scalar (TypeVar mempty (QualName [] t) [])) (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) ] @@ -374,7 +376,7 @@ instance SubstRanks (TypeBase SComp u) where substRanks t = pure t instance SubstRanks Ct where - substRanks (CtEq t1 t2) = CtEq <$> substRanks t1 <*> substRanks t2 + substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" updAM :: Map VName Int -> Exp -> Exp diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d73c6758d9..6b1e06fc54 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -267,27 +267,27 @@ asStructType loc (Scalar (Sum cs)) = Scalar . Sum <$> traverse (mapM (asStructType loc)) cs asStructType loc t@(Scalar (TypeVar u _ _)) = do t' <- newType loc "artificial" u - ctEq (toType t') t + ctEq (Reason (locOf loc)) (toType t') t pure t' asStructType loc t@(Array u _ _) = do t' <- newType loc "artificial" u - ctEq (toType t') t + ctEq (Reason (locOf loc)) (toType t') t pure t' addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} -ctEq :: TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () -ctEq t1 t2 = +ctEq :: Reason -> TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () +ctEq reason t1 t2 = -- As a minor optimisation, do not add constraint if the types are -- equal. - unless (t1' == t2') $ addCt $ CtEq t1' t2' + unless (t1' == t2') $ addCt $ CtEq reason t1' t2' where t1' = t1 `setUniqueness` NoUniqueness t2' = t2 `setUniqueness` NoUniqueness -ctAM :: SVar -> SVar -> Shape SComp -> TermM () -ctAM r m f = addCt $ CtAM r m f +ctAM :: Reason -> SVar -> SVar -> Shape SComp -> TermM () +ctAM reason r m f = addCt $ CtAM reason r m f localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -361,11 +361,11 @@ arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why [pt] e = do - ctEq (Scalar $ Prim pt) (expType e) + ctEq (Reason (locOf e)) (Scalar $ Prim pt) (expType e) pure e require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts - ctEq t $ expType e + ctEq (Reason (locOf e)) t $ expType e pure e -- | Instantiate a type scheme with fresh type variables for its type @@ -448,7 +448,7 @@ patLitMkType (PatLitPrim v) _ = checkSizeExp' :: ExpBase NoInfo VName -> TermM Exp checkSizeExp' e = do e' <- checkExp e - ctEq (expType e') (Scalar (Prim (Signed Int64))) + ctEq (Reason (locOf e)) (expType e') (Scalar (Prim (Signed Int64))) pure e' checkPat' :: @@ -477,7 +477,7 @@ checkPat' (TuplePat ps loc) (Ascribed t) <*> pure loc | otherwise = do ps_t <- replicateM (length ps) (newType loc "t" Observe) - ctEq (toType (Scalar (tupleRecord ps_t))) (toType t) + ctEq (Reason (locOf loc)) (toType (Scalar (tupleRecord ps_t))) (toType t) TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc @@ -487,7 +487,7 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs - ctEq (Scalar (Record p_fs')) $ toType t + ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) $ toType t st <- asStructType loc $ Scalar (Record p_fs') checkPat' p $ Ascribed $ toParam Observe st where @@ -507,7 +507,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do case maybe_outer_t of Ascribed outer_t -> do - ctEq (toType st') (toType outer_t) + ctEq (Reason (locOf loc)) (toType st') (toType outer_t) PatAscription <$> checkPat' p (Ascribed st') <*> pure t' @@ -519,7 +519,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do t' <- patLitMkType l loc - ctEq (toType t') (toType t) + ctEq (Reason (locOf loc)) (toType t') (toType t) pure $ PatLit l (Info t') loc checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc @@ -542,7 +542,7 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do p_t <- newType (srclocOf p) "t" Observe checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' - ctEq t' (toType t) + ctEq (Reason (locOf loc)) t' (toType t) t'' <- asStructType loc t' pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do @@ -640,8 +640,8 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do m_var = Var (QualName [] m) unit_info mempty lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m)) a - ctAM r m $ fmap toSComp (toShape m_var <> fframe) - ctEq lhs rhs + ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) + ctEq (Reason (locOf loc)) lhs rhs debugTraceM 3 $ unlines [ "## checkApplyOne", @@ -688,7 +688,7 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do split ftype' = do a <- newType loc "arg" NoUniqueness b <- newType loc "res" Nonunique - ctEq ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] @@ -708,17 +708,17 @@ isSlice DimFix {} = False -- Add constraints saying that the first type has a (potentially -- nested) field containing the second type. mustHaveFields :: SrcLoc -> Type -> [Name] -> Type -> TermM () -mustHaveFields _ t [] ve_t = +mustHaveFields loc t [] ve_t = -- This case is probably never reached. - ctEq t ve_t + ctEq (Reason (locOf loc)) t ve_t mustHaveFields loc t [f] ve_t = do rt :: Type <- newTypeWithField loc "ft" f ve_t - ctEq t rt + ctEq (Reason (locOf loc)) t rt mustHaveFields loc t (f : fs) ve_t = do ft <- newType loc "ft" NoUniqueness rt <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t - ctEq t rt + ctEq (Reason (locOf loc)) t rt checkCase :: StructType -> @@ -741,7 +741,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq (toType c_t) (toType cs_t) + ctEq (Reason (locOf c)) (toType c_t) (toType cs_t) pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -779,7 +779,7 @@ checkRetDecl :: checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (expType body) (toType st) + ctEq (Reason (locOf body)) (expType body) (toType st) pure $ Just te' checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) @@ -830,7 +830,7 @@ checkExp (ArrayLit es _ loc) = do et <- newType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e - ctEq (expType e') (toType et) + ctEq (Reason (locOf loc)) (expType e') (toType et) pure e' let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et pure $ ArrayLit es' (Info arr_t) loc @@ -1003,19 +1003,19 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do start' <- require "use in range expression" anyIntType =<< checkExp start let check e = do e' <- checkExp e - ctEq (expType start') (expType e') + ctEq (Reason (locOf e')) (expType start') (expType e') pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end range_t <- newType loc "range" NoUniqueness - ctEq (toType range_t) (arrayOfRank 1 (expType start')) + ctEq (Reason (locOf start')) (toType range_t) (arrayOfRank 1 (expType start')) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e kt <- newType loc "kt" NoUniqueness t <- newTypeWithField loc "t" k kt - ctEq (expType e') t + ctEq (Reason (locOf e')) (expType e') t kt' <- asStructType loc kt pure $ Project k e' (Info kt') loc -- @@ -1031,8 +1031,8 @@ checkExp (IndexSection slice NoInfo loc) = do index_elem_t <- newType loc "index_elem" NoUniqueness index_res_t <- newType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (toType index_arg_t) $ arrayOfRank num_slices index_elem_t - ctEq index_res_t $ arrayOfRank (length slice) index_elem_t + ctEq (Reason (locOf loc)) (toType index_arg_t) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t index_res_t' <- asStructType loc index_res_t let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' pure $ IndexSection slice' (Info ft) loc @@ -1043,8 +1043,8 @@ checkExp (AppExp (Index e slice loc) _) = do index_t <- newType loc "index" NoUniqueness index_elem_t <- newType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (toType index_t) $ arrayOfRank num_slices index_elem_t - ctEq (expType e') $ arrayOfRank (length slice) index_elem_t + ctEq (Reason (locOf loc)) (toType index_t) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf e')) (expType e') $ arrayOfRank (length slice) index_elem_t pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) -- checkExp (Update src slice ve loc) = do @@ -1053,8 +1053,8 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (expType src') $ arrayOfRank (length slice) update_elem_t - ctEq (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf src')) (expType src') $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do @@ -1065,8 +1065,8 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (toType src_t) $ arrayOfRank (length slice) update_elem_t - ctEq (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf loc)) (toType src_t) $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) @@ -1076,8 +1076,8 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do e2' <- checkExp e2 e3' <- checkExp e3 - ctEq (expType e1') (Scalar (Prim Bool)) - ctEq (expType e2') (expType e3') + ctEq (Reason (locOf e1')) (expType e1') (Scalar (Prim Bool)) + ctEq (Reason (locOf loc)) (expType e2') (expType e3') pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) -- @@ -1096,17 +1096,17 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do let i' = Ident i (Info (typeOf bound')) iloc bind [i'] $ do body' <- checkExp body - ctEq (expType arg') (expType body') + ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (For i' bound', body') While cond -> do cond' <- checkExp cond body' <- checkExp body - ctEq (expType arg') (expType body') + ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" NoUniqueness - ctEq (expType arr') $ arrayOfRank 1 (toType elem_t) + ctEq (Reason (locOf arr')) (expType arr') $ arrayOfRank 1 (toType elem_t) bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') @@ -1118,12 +1118,12 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do checkExp (Ascript e te loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (expType e') (toType st) + ctEq (Reason (locOf e')) (expType e') (toType st) pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (expType e') (toType st) + ctEq (Reason (locOf e')) (expType e') (toType st) pure $ Coerce e' te' (Info (toStruct st)) loc doDefault :: From ef20fbe82a1f8df507874b913ecfc17fa7091b43 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 7 Mar 2024 05:38:00 -0800 Subject: [PATCH 177/296] Easy fix. --- src/Language/Futhark/TypeChecker/Rank.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 9af5623d5e..7443c38a47 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -47,7 +47,7 @@ instance Rank (Shape SComp) where instance Rank ScalarType where rank Prim {} = constant 0 - rank (TypeVar _ (QualName [] v) []) = var v -- FIXME - might not be a type variable. + rank (TypeVar _ (QualName [] v) []) = var v rank (TypeVar {}) = constant 0 rank (Arrow {}) = constant 0 rank (Record {}) = constant 0 From 37cc601699cde242d57a42297c01efbea8b253f9 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Thu, 7 Mar 2024 05:41:20 -0800 Subject: [PATCH 178/296] More FIXME extermination. --- src/Language/Futhark/TypeChecker/Rank.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 7443c38a47..5f3775a748 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -307,13 +307,14 @@ newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar newTyVar t = do t' <- lift $ newTypeName (baseName t) shape <- rankToShape t + loc <- (locOf . snd . fromJust . (M.!? t)) <$> asks envTyVars modify $ \s -> s { substNewVars = M.insert t t' $ substNewVars s, substNewCts = substNewCts s ++ [ CtEq - (Reason mempty) -- FIXME + (Reason loc) (Scalar (TypeVar mempty (QualName [] t) [])) (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) ] From 75e5be4b5e96ccbe8f6df54d8b622f7510e8475a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 15:28:24 +0100 Subject: [PATCH 179/296] Proper AUTOMAP for single expressions. --- src/Language/Futhark/TypeChecker/Terms2.hs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6b1e06fc54..5252a93de9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1249,14 +1249,27 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do checkSingleExp :: ExpBase NoInfo VName -> - TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - solution <- - bitraverse pure (traverse (doDefaults mempty)) $ solve cts tyvars - pure (solution, e') + ((cts', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars [] e' + case solve cts' tyvars' of + Left err -> pure (Left err, e'') + Right (unconstrained, solution) -> do + let (generalised, unconstrained') = + generalise (typeOf e'') unconstrained solution + solution' <- doDefaults (map typeParamName generalised) solution + pure + ( Right + ( generalised, + -- See #1552 for why we resolve unconstrained and + -- un-generalised type variables to (). + M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' + ), + e'' + ) -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. From 941dec2237f5c3decb344a5cbff88f695b2916d4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 15:40:39 +0100 Subject: [PATCH 180/296] Has to be written like this. --- tests/automap/mri-q.fut | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index 3a4648c7b9..f53b5df7a6 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -18,7 +18,7 @@ def main_orig [numK][numX] let qr = map1 (map f32.cos >-> map2 (*) phiMag >-> f32.sum) expArgs let qi = map1 (map f32.sin >-> map2 (*) phiMag >-> f32.sum) expArgs in (qr, qi) - + def main_am [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) @@ -31,11 +31,11 @@ def main_am [numK][numX] let qr = f32.sum (f32.cos expArgs * phiMag) let qi = f32.sum (f32.sin expArgs * phiMag) in (qr, qi) - + entry main [numK][numX] (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) (phiR: [numK]f32) (phiI: [numK]f32) = let (qr, qi) = main_orig kx ky kz x y z phiR phiI let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI - in and (map2 (==) qr qr_am && qi == qi_am) + in and (map2 (==) qr qr_am && map2 (==) qi qi_am) From 82c8458dc7dda7dbf2f6a4ca171ebcea37ce06d1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 7 Mar 2024 15:40:47 +0100 Subject: [PATCH 181/296] Reduce duplication. --- src/Language/Futhark/TypeChecker/Terms2.hs | 40 ++++++++++------------ 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 5252a93de9..4226fef858 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1177,6 +1177,22 @@ generalise fun_t unconstrained solution = Just (Right t) -> foldMap expandTyVars $ typeVars t _ -> S.singleton v +generaliseAndDefaults :: + [VName] -> + Solution -> + StructType -> + TermM ([TypeParam], M.Map VName (TypeBase () NoUniqueness)) +generaliseAndDefaults unconstrained solution t = do + let (generalised, unconstrained') = + generalise t unconstrained solution + solution' <- doDefaults (map typeParamName generalised) solution + pure + ( generalised, + -- See #1552 for why we resolve unconstrained and + -- un-generalised type variables to (). + M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' + ) + checkValDef :: ( VName, Maybe (TypeExp (ExpBase NoInfo VName) VName), @@ -1237,15 +1253,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do foldFunType (map patternType params') (RetType [] $ toRes Nonunique (typeOf body')) - (generalised, unconstrained') = - generalise fun_t unconstrained solution - solution' <- doDefaults (map typeParamName generalised) solution - pure - ( generalised, - -- See #1552 for why we resolve unconstrained and - -- un-generalised type variables to (). - M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' - ) + generaliseAndDefaults unconstrained solution fun_t checkSingleExp :: ExpBase NoInfo VName -> @@ -1258,18 +1266,8 @@ checkSingleExp e = runTermM $ do case solve cts' tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do - let (generalised, unconstrained') = - generalise (typeOf e'') unconstrained solution - solution' <- doDefaults (map typeParamName generalised) solution - pure - ( Right - ( generalised, - -- See #1552 for why we resolve unconstrained and - -- un-generalised type variables to (). - M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' - ), - e'' - ) + x <- generaliseAndDefaults unconstrained solution $ typeOf e'' + pure (Right x, e'') -- | Type-check a single size expression in isolation. This expression may -- turn out to be polymorphic, in which case it is unified with i64. From f0bef225284a8e07efcfb55fc81e8434084b9a36 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 11:47:22 +0100 Subject: [PATCH 182/296] Work on sum types. --- .../Futhark/TypeChecker/Constraints.hs | 33 ++++++++++++++----- .../Futhark/TypeChecker/Terms/Monad.hs | 1 + src/Language/Futhark/TypeChecker/Terms/Pat.hs | 12 ++++++- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index eddc37691c..71765d7e0d 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -289,8 +289,8 @@ solveCt ct = solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () solveTyVar (_, (_, TyVarFree {})) = pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do - t <- lookupTyVar tv - case t of + tv_t <- lookupTyVar tv + case tv_t of Nothing -> pure () Just t' | t' `elem` map (Scalar . Prim) pts -> pure () @@ -309,12 +309,29 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> solveCt $ CtEq (Reason loc) t1 t2 Just tv_t' -> - throwError $ - TypeError loc mempty $ - "Type must be record with fields" - indent 2 (pretty (Scalar (Record fs1))) - "but inferred to be" - indent 2 (pretty tv_t') + throwError . TypeError loc mempty $ + "Type must be record with fields" + indent 2 (pretty (Scalar (Record fs1))) + "but inferred to be" + indent 2 (pretty tv_t') +solveTyVar (tv, (_, TyVarSum loc cs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Nothing -> pure () + Just (Scalar (Sum cs2)) + | all (`M.member` cs2) (M.keys cs1), + cs3 <- M.toList $ M.intersectionWith (,) cs1 cs2, + all (sameLength . snd) cs3 -> + forM_ cs3 $ \(_k, (t1s, t2s)) -> + mapM_ solveCt $ zipWith (CtEq (Reason loc)) t1s t2s + Just tv_t' -> + throwError . TypeError loc mempty $ + "Type must be sum type with constructors" + indent 2 (pretty (Scalar (Sum cs1))) + "but inferred to be" + indent 2 (pretty tv_t') + where + sameLength (x, y) = length x == length y solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) solve constraints tyvars = diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 018a3e920a..42c5d53ab2 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -87,6 +87,7 @@ unusedSize p = data Inferred t = NoneInferred | Ascribed t + deriving (Show) instance Functor Inferred where fmap _ NoneInferred = NoneInferred diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index b86dd63616..265882c541 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -178,9 +178,19 @@ checkPat' sizes (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' _ (PatLit l info loc) _ = pure $ PatLit l info loc -checkPat' sizes (PatConstr n info ps loc) _ = do +checkPat' sizes (PatConstr n info ps loc) NoneInferred = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps pure $ PatConstr n info ps' loc +checkPat' sizes (PatConstr n info ps loc) (Ascribed (Scalar (Sum cs))) + | Just ts <- M.lookup n cs = do + ps' <- zipWithM (\p t -> checkPat' sizes p (Ascribed t)) ps ts + pure $ PatConstr n info ps' loc +checkPat' _ p t = + error . unlines $ + [ "checkPat': bad case", + prettyString p, + show t + ] checkPat :: [(SizeBinder VName, QualName VName)] -> From cefc91b0071ff6d30bfa0c5725b9617e005b3f51 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 12:59:01 +0100 Subject: [PATCH 183/296] Fix for-in loops. --- src/Language/Futhark/TypeChecker/Terms2.hs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 4226fef858..7cd3f10a1c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1096,12 +1096,10 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do let i' = Ident i (Info (typeOf bound')) iloc bind [i'] $ do body' <- checkExp body - ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (For i' bound', body') While cond -> do cond' <- checkExp cond body' <- checkExp body - ctEq (Reason (locOf arg')) (expType arg') (expType body') pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr @@ -1110,6 +1108,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do bindLetPat elemp elem_t $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') + ctEq (Reason (locOf loc)) (expType arg') (expType body') pure $ AppExp (Loop [] pat' arg' form' body' loc) From 80c6d80dc06bef4d0ea57032c46fde262c78417f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 13:19:04 +0100 Subject: [PATCH 184/296] Detect duplicate fields. --- src/Language/Futhark/TypeChecker/Terms2.hs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 7cd3f10a1c..13c9ef2cd8 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -59,7 +59,7 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.FreshNames qualified as FreshNames import Futhark.MonadFreshNames hiding (newName) -import Futhark.Util (debugTraceM, mapAccumLM) +import Futhark.Util (debugTraceM, mapAccumLM, nubOrd) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Constraints @@ -481,6 +481,16 @@ checkPat' (TuplePat ps loc) (Ascribed t) TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc +checkPat' p@(RecordPat p_fs loc) _ + | Just (f, fp) <- L.find (("_" `T.isPrefixOf`) . nameToText . fst) p_fs = + typeError fp mempty $ + "Underscore-prefixed fields are not allowed." + "Did you mean" + <> dquotes (pretty (T.drop 1 (nameToText f)) <> "=_") + <> "?" + | nubOrd (map fst p_fs) /= map fst p_fs = + typeError loc mempty $ + "Duplicate fields in record pattern" <+> pretty p <> "." checkPat' p@(RecordPat p_fs loc) (Ascribed t) | Scalar (Record t_fs) <- t, L.sort (map fst p_fs) == L.sort (M.keys t_fs) = From c1fb806f8e51b132d248d5a09cba1213bb52d1ef Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 15:05:27 +0100 Subject: [PATCH 185/296] Handle type arguments here. --- src/Language/Futhark/TypeChecker/Types.hs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index b89802c5e0..02126dbfd8 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -534,10 +534,13 @@ substTypesAny lookupSubst ot = -- | Substitution without caring about sizes. substTyVars :: (Monoid u) => (VName -> Maybe (TypeBase d NoUniqueness)) -> TypeBase d u -> TypeBase d u -substTyVars f t@(Scalar (TypeVar u (QualName qs v) args)) = - case f v of +substTyVars f (Scalar (TypeVar u qn args)) = + case f $ qualLeaf qn of Just t' -> second (const mempty) $ substTyVars f t' - Nothing -> t + Nothing -> Scalar (TypeVar u qn (map onArg args)) + where + onArg (TypeArgType t) = TypeArgType $ substTyVars f t + onArg (TypeArgDim e) = TypeArgDim e substTyVars _ (Scalar (Prim pt)) = Scalar $ Prim pt substTyVars f (Scalar (Record fs)) = Scalar $ Record $ M.map (substTyVars f) fs substTyVars f (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map $ substTyVars f) cs From c4e11a9dcb81384f7a1b28e2923e984717393211 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 15:10:28 +0100 Subject: [PATCH 186/296] let should not be generalised. --- docs/language-reference.rst | 8 +++++--- tests/types/inference22.fut | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/language-reference.rst b/docs/language-reference.rst index 51719d62a3..b9d3b6ff37 100644 --- a/docs/language-reference.rst +++ b/docs/language-reference.rst @@ -994,9 +994,11 @@ Syntactic sugar for ``let a = a with [i] = v in a``. ............................... Bind ``f`` to a function with the given parameters and definition -(``e``) and evaluate ``body``. The function will be treated as -aliasing any free variables in ``e``. The function is not in scope of -itself, and hence cannot be recursive. +(``e``) and evaluate ``body``. The function will be treated as +aliasing any free variables in ``e``. The function is not in scope of +itself, and hence cannot be recursive. While the function can be made +polymorphic by putting in explicit size parameters, it is not +automatically generalised the way top level functions are. ``loop pat = initial for x in a do loopbody`` ............................................. diff --git a/tests/types/inference22.fut b/tests/types/inference22.fut index 4e367db82f..dbe574e411 100644 --- a/tests/types/inference22.fut +++ b/tests/types/inference22.fut @@ -2,5 +2,5 @@ -- == def main (x: i32) (y: bool) = - let f x y = (y,x) + let f 'a 'b (x: a) (y: b) = (y,x) in (f x y, f y x) From 8cb7c7d710176b3e4a3df6df9fc30781a1960d17 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Mon, 11 Mar 2024 07:27:05 -0700 Subject: [PATCH 187/296] Stop erroneously changing the type of automapped functions. --- src/Futhark/Internalise/FullNormalise.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 9b27a6cf46..789e0d3c85 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -507,14 +507,13 @@ expandAMAnnotations e = do arg_es' <- mapM expandAMAnnotations arg_es let diets = funDiets $ typeOf f withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do - inner_f <- setNewType f' $ innerFType (typeOf f') ams let rettype = - case unfoldFunTypeWithRet $ typeOf inner_f of + case unfoldFunTypeWithRet $ typeOf f' of Nothing -> error "Function type expected." Just (ptypes, f_ret) -> foldFunType (drop (length args') ptypes) f_ret pure $ - mkApply inner_f (zip3 exts (repeat mempty) args') $ + mkApply f' (zip3 exts (repeat mempty) args') $ res {appResType = rettype} (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x From c6fa39eb066536cb55ac95ed965c49a10b85b578 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 16:31:56 +0100 Subject: [PATCH 188/296] Some sum fixes. --- src/Language/Futhark/TypeChecker/Terms.hs | 6 +++--- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index c54c0b5805..b3afad3954 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -783,10 +783,10 @@ checkExp (AppExp (Loop _ mergepat mergeexp form loopbody loc) _) = do AppExp (Loop sparams mergepat' mergeexp' form' loopbody' loc) (Info appres) -checkExp (Constr name es _ loc) = do - t <- newTypeVar loc "t" +checkExp (Constr name es (Info t) loc) = do + t' <- replaceTyVars loc t es' <- mapM checkExp es - pure $ Constr name es' (Info t) loc + pure $ Constr name es' (Info t') loc checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e mt <- expType e' diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 265882c541..2a648ea8ab 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -181,10 +181,10 @@ checkPat' _ (PatLit l info loc) _ = checkPat' sizes (PatConstr n info ps loc) NoneInferred = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps pure $ PatConstr n info ps' loc -checkPat' sizes (PatConstr n info ps loc) (Ascribed (Scalar (Sum cs))) +checkPat' sizes (PatConstr n _ ps loc) (Ascribed (Scalar (Sum cs))) | Just ts <- M.lookup n cs = do ps' <- zipWithM (\p t -> checkPat' sizes p (Ascribed t)) ps ts - pure $ PatConstr n info ps' loc + pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' _ p t = error . unlines $ [ "checkPat': bad case", From 4de7ecf60a7cba572412b28dac8eaf5fd2b3c51d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 16:39:41 +0100 Subject: [PATCH 189/296] Detect more ambiguities. --- src/Language/Futhark/TypeChecker/Constraints.hs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 71765d7e0d..461b256e51 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -303,7 +303,11 @@ solveTyVar (tv, (_, TyVarPrim loc pts)) = do solveTyVar (tv, (_, TyVarRecord loc fs1)) = do tv_t <- lookupTyVar tv case tv_t of - Nothing -> pure () + Nothing -> + throwError . TypeError loc mempty $ + "Type is ambiguous." + "Must be a record with fields" + indent 2 (pretty (Scalar (Record fs1))) Just (Scalar (Record fs2)) | all (`M.member` fs2) (M.keys fs1) -> forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> @@ -317,7 +321,11 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do solveTyVar (tv, (_, TyVarSum loc cs1)) = do tv_t <- lookupTyVar tv case tv_t of - Nothing -> pure () + Nothing -> + throwError . TypeError loc mempty $ + "Type is ambiguous." + "Must be a sum type with constructors" + indent 2 (pretty (Scalar (Sum cs1))) Just (Scalar (Sum cs2)) | all (`M.member` cs2) (M.keys cs1), cs3 <- M.toList $ M.intersectionWith (,) cs1 cs2, From c4610ba6e37ef4eb0ac248f23fa3215b655cb141 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 17:03:17 +0100 Subject: [PATCH 190/296] Add notion of equality type. --- src/Language/Futhark/TypeChecker/Constraints.hs | 6 ++++++ src/Language/Futhark/TypeChecker/Rank.hs | 2 ++ src/Language/Futhark/TypeChecker/Terms2.hs | 3 ++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 461b256e51..4a7d30ae98 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -93,6 +93,8 @@ data TyVarInfo TyVarRecord Loc (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum Loc (M.Map Name [Type]) + | -- | Must be a type that supports equality. + TyVarEql Loc deriving (Show, Eq) instance Pretty TyVarInfo where @@ -100,12 +102,14 @@ instance Pretty TyVarInfo where pretty (TyVarPrim _ pts) = "∈" <+> pretty pts pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs + pretty (TyVarEql _) = "equality" instance Located TyVarInfo where locOf (TyVarFree loc) = loc locOf (TyVarPrim loc _) = loc locOf (TyVarRecord loc _) = loc locOf (TyVarSum loc _) = loc + locOf (TyVarEql loc) = loc type TyVar = VName @@ -340,6 +344,8 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do indent 2 (pretty tv_t') where sameLength (x, y) = length x == length y +solveTyVar (_, (_, TyVarEql _)) = + pure () solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) solve constraints tyvars = diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 5f3775a748..3826de135c 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -137,6 +137,8 @@ addTyVarInfo tv (_, TyVarRecord {}) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarEql {}) = + addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: [Ct] -> TyVars -> LinearProg mkLinearProg cs tyVars = diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 13c9ef2cd8..51bac62776 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -410,7 +410,8 @@ lookupVar loc qn@(QualName qs name) = do outer_env <- asks termOuterEnv asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do - argtype <- newTypeOverloaded loc "t" anyPrimType + argtype <- + tyVarType NoUniqueness <$> newTyVarWith "t" (TyVarEql (locOf loc)) pure $ foldFunType [toParam Observe argtype, toParam Observe argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts From 521846f08121a09ce9251fc6c1e4e54b47e6798d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 11 Mar 2024 17:04:17 +0100 Subject: [PATCH 191/296] Remove duplicate comment. --- src/Language/Futhark/Prop.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index d45cc3d294..154fce6e59 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -1067,8 +1067,6 @@ intrinsics = ++ [Bool] ) ++ - -- This overrides the ! from Primitive. - -- This overrides the ! from Primitive. [ ( "!", IntrinsicOverloadedFun From 300cce6492d1cd9ff173279b0f82bfaeb1baa8f2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 11:53:47 +0100 Subject: [PATCH 192/296] Fix check. --- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 51bac62776..94ffe58977 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -489,7 +489,7 @@ checkPat' p@(RecordPat p_fs loc) _ "Did you mean" <> dquotes (pretty (T.drop 1 (nameToText f)) <> "=_") <> "?" - | nubOrd (map fst p_fs) /= map fst p_fs = + | length (nubOrd (map fst p_fs)) /= length (map fst p_fs) = typeError loc mempty $ "Duplicate fields in record pattern" <+> pretty p <> "." checkPat' p@(RecordPat p_fs loc) (Ascribed t) From 6a690521dec2ec33f19c7c867f6de5b3445ace15 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:19:36 +0100 Subject: [PATCH 193/296] Avoid artificial type variables in constraints. --- src/Language/Futhark/Prop.hs | 6 +- src/Language/Futhark/TypeChecker/Rank.hs | 30 +- src/Language/Futhark/TypeChecker/Terms2.hs | 408 +++++++++++++-------- 3 files changed, 280 insertions(+), 164 deletions(-) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 154fce6e59..0507139fdc 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -497,7 +497,7 @@ typeOf (Attr _ e _) = typeOf e typeOf (AppExp _ (Info res)) = appResType res -- | The type of a function with the given parameters and return type. -funType :: [Pat ParamType] -> ResRetType -> StructType +funType :: [Pat (TypeBase d Diet)] -> RetTypeBase d Uniqueness -> TypeBase d NoUniqueness funType params ret = let RetType _ t = foldr (arrow . patternParam) ret params in toStruct t @@ -507,7 +507,7 @@ funType params ret = -- | @foldFunType ts ret@ creates a function type ('Arrow') that takes -- @ts@ as parameters and returns @ret@. -foldFunType :: [ParamType] -> ResRetType -> StructType +foldFunType :: [TypeBase d Diet] -> RetTypeBase d Uniqueness -> TypeBase d NoUniqueness foldFunType ps ret = let RetType _ t = foldr arrow ret ps in toStruct t @@ -621,7 +621,7 @@ patternStructType = toStruct . patternType -- | When viewed as a function parameter, does this pattern correspond -- to a named parameter of some type? -patternParam :: Pat ParamType -> (PName, Diet, StructType) +patternParam :: Pat (TypeBase d Diet) -> (PName, Diet, TypeBase d NoUniqueness) patternParam (PatParens p _) = patternParam p patternParam (PatAttr _ p _) = diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 3826de135c..fba519b544 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -218,13 +218,13 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), [Pat ParamType], Exp) -rankAnalysis1 loc cs tyVars params body = do - solutions <- rankAnalysis loc cs tyVars params body +rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp) +rankAnalysis1 loc cs tyVars artificial params body = do + solutions <- rankAnalysis loc cs tyVars artificial params body case solutions of [sol] -> pure sol sols -> do - let (_, _, bodies') = unzip3 sols + let (_, _, _, bodies') = L.unzip4 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -232,14 +232,15 @@ rankAnalysis1 loc cs tyVars params body = do ] ++ map pretty bodies' -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), [Pat ParamType], Exp)] -rankAnalysis _ [] tyVars params body = pure [(([], tyVars), params, body)] -rankAnalysis loc cs tyVars params body = do +rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp)] +rankAnalysis _ [] tyVars artificial params body = pure [(([], tyVars), artificial, params, body)] +rankAnalysis loc cs tyVars artificial params body = do rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps let bodys = map (`updAM` body) rank_maps params' = map ((`map` params) . updAMPat) rank_maps - pure $ zip3 cts_tyvars' params' bodys + artificial' <- mapM (substRankInfoArtificial tyVars artificial) rank_maps + pure $ L.zip4 cts_tyvars' artificial' params' bodys where cs' = foldMap (splitFuncs . distribute) cs splitFuncs @@ -265,6 +266,12 @@ substRankInfo cs tyVars rankmap = do isCtAM (CtAM {}) = True isCtAM _ = False +substRankInfoArtificial :: (MonadTypeChecker m) => TyVars -> M.Map VName Type -> Map VName Int -> m (M.Map VName Type) +substRankInfoArtificial tyvars artificial rankmap = do + (artificial', _, _) <- + runSubstT tyvars rankmap $ traverse substRanks artificial + pure artificial' + runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) runSubstT tyVars rankmap (SubstT m) = do let env = @@ -372,6 +379,7 @@ instance SubstRanks (TypeBase SComp u) where ta' <- substRanks ta tr' <- substRanks tr pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) + substRanks (Scalar (Record fs)) = Scalar . Record <$> traverse substRanks fs substRanks (Array u shape t) = do shape' <- substRanks shape t' <- substRanks $ Scalar t @@ -391,15 +399,15 @@ updAM rank_map e = in AppExp (Apply f' args' loc) res AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res - _ -> runIdentity $ astMap m e + _ -> runIdentity $ astMap mapper e where dimToRank (Var (QualName [] x) _ _) = replicate (rank_map M.! x) (TupLit mempty mempty) - dimToRank e = error $ prettyString e + dimToRank e' = error $ prettyString e' shapeToRank = Shape . foldMap dimToRank upd (AutoMap r m f) = AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) - m = + mapper = identityMapper { mapOnExp = pure . updAM rank_map } diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 94ffe58977..6839cd3d29 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -79,14 +79,11 @@ instance Functor Inferred where fmap f (Ascribed t) = Ascribed (f t) data ValBinding - = BoundV [TypeParam] StructType + = BoundV [TypeParam] Type | OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType) | EqualityF deriving (Show) -expType :: Exp -> Type -expType = toType . typeOf - data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding, scopeTypeTable :: M.Map VName TypeBinding, @@ -116,7 +113,9 @@ data TermState = TermState termTyVars :: TyVars, termCounter :: !Int, termWarnings :: Warnings, - termNameSource :: VNameSource + termNameSource :: VNameSource, + -- | Mapping from artificial type variables to the actual types they represent. + termArtificial :: M.Map TyVar Type } newtype TermM a @@ -143,7 +142,7 @@ envToTermScope env = } where vtable = M.map valBinding $ envVtable env - valBinding (TypeM.BoundV tps v) = BoundV tps v + valBinding (TypeM.BoundV tps v) = BoundV tps $ toType v initialTermScope :: TermScope initialTermScope = @@ -169,7 +168,7 @@ initialTermScope = addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = Just ( name, - BoundV tvs $ foldFunType pts rt + BoundV tvs $ toType $ foldFunType pts rt ) addIntrinsicF (name, IntrinsicEquality) = Just (name, EqualityF) @@ -194,7 +193,8 @@ runTermM (TermM m) = do termTyVars = mempty, termWarnings = mempty, termNameSource = src, - termCounter = 0 + termCounter = 0, + termArtificial = mempty } case runExcept (runStateT (runReaderT m initial_env) initial_state) of Left (ws, e) -> do @@ -254,26 +254,42 @@ newSVar _loc desc = do i <- incCounter newID $ mkTypeVarName desc i -asStructType :: SrcLoc -> TypeBase SComp u -> TermM (TypeBase Size u) -asStructType _ (Scalar (Prim pt)) = pure $ Scalar $ Prim pt -asStructType _ (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] -asStructType loc (Scalar (Arrow u pname d t1 (RetType ext t2))) = do - t1' <- asStructType loc t1 - t2' <- asStructType loc t2 - pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' -asStructType loc (Scalar (Record fs)) = - Scalar . Record <$> traverse (asStructType loc) fs -asStructType loc (Scalar (Sum cs)) = - Scalar . Sum <$> traverse (mapM (asStructType loc)) cs -asStructType loc t@(Scalar (TypeVar u _ _)) = do - t' <- newType loc "artificial" u - ctEq (Reason (locOf loc)) (toType t') t - pure t' -asStructType loc t@(Array u _ _) = do - t' <- newType loc "artificial" u - ctEq (Reason (locOf loc)) (toType t') t +newArtificial :: u -> TypeBase SComp u -> TermM (TypeBase Size u) +newArtificial u t = do + v <- newID "artificial" + let t' = tyVarType u v + modify $ \s -> s {termArtificial = M.insert v (second (const NoUniqueness) t) $ termArtificial s} pure t' +-- The AST requires annotations to be StructTypes, but the type +-- checker works with Types. This creates artificial type "variables" +-- that allow us to connect the AST annotations with the actual +-- inferred types. The artificial variables should never occur in +-- constraints - they can be substituted away with asType. +asStructType :: TypeBase SComp u -> TermM (TypeBase Size u) +asStructType (Scalar (Prim pt)) = pure $ Scalar $ Prim pt +asStructType (Scalar (TypeVar u v [])) = pure $ Scalar $ TypeVar u v [] +asStructType (Scalar (Arrow u pname d t1 (RetType ext t2))) = do + t1' <- asStructType t1 + t2' <- asStructType t2 + pure $ Scalar $ Arrow u pname d t1' $ RetType ext t2' +asStructType (Scalar (Record fs)) = + Scalar . Record <$> traverse asStructType fs +asStructType (Scalar (Sum cs)) = + Scalar . Sum <$> traverse (mapM asStructType) cs +asStructType t@(Scalar (TypeVar u _ _)) = + newArtificial u t +asStructType t@(Array u _ _) = do + newArtificial u t + +asType :: (Monoid u) => TypeBase Size u -> TermM (TypeBase SComp u) +asType t = do + artificial <- gets termArtificial + pure $ substTyVars (`M.lookup` artificial) (toType t) + +expType :: Exp -> TermM Type +expType = asType . typeOf -- NOTE: Only place you should use typeOf. + addCt :: Ct -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} @@ -336,8 +352,10 @@ instance MonadTypeChecker TermM where i <- incCounter newID $ mkTypeVarName name i - bindVal v (TypeM.BoundV tps t) = localScope $ \scope -> - scope {scopeVtable = M.insert v (BoundV tps t) $ scopeVtable scope} + bindVal v (TypeM.BoundV tps t) m = do + t' <- asType t + let f scope = scope {scopeVtable = M.insert v (BoundV tps t') $ scopeVtable scope} + localScope f m lookupType qn = do outer_env <- asks termOuterEnv @@ -361,11 +379,13 @@ arrayOfRank n = arrayOf $ Shape $ replicate n SDim require :: T.Text -> [PrimType] -> Exp -> TermM Exp require _why [pt] e = do - ctEq (Reason (locOf e)) (Scalar $ Prim pt) (expType e) + e_t <- expType e + ctEq (Reason (locOf e)) (Scalar $ Prim pt) e_t pure e require _why pts e = do t :: Type <- newTypeOverloaded (srclocOf e) "t" pts - ctEq (Reason (locOf e)) t $ expType e + e_t <- expType e + ctEq (Reason (locOf e)) t e_t pure e -- | Instantiate a type scheme with fresh type variables for its type @@ -375,18 +395,18 @@ instTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> - StructType -> - TermM ([VName], StructType) + Type -> + TermM ([VName], Type) instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> case tparam of TypeParamType _ v _ -> do v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v - pure $ Just (v, (typeParamName tparam, Subst [] $ RetType [] $ tyVarType NoUniqueness v')) + pure $ Just (v, (typeParamName tparam, tyVarType NoUniqueness v')) TypeParamDim {} -> pure Nothing - let t' = applySubst (`lookup` substs) t + let t' = substTyVars (`lookup` substs) t pure (names, t') lookupMod :: QualName VName -> TermM Mod @@ -396,7 +416,7 @@ lookupMod qn@(QualName _ name) = do Nothing -> error $ "lookupMod: " <> show qn Just m -> pure m -lookupVar :: SrcLoc -> QualName VName -> TermM StructType +lookupVar :: SrcLoc -> QualName VName -> TermM Type lookupVar loc qn@(QualName qs name) = do scope <- lookupQualNameEnv qn case M.lookup name $ scopeVtable scope of @@ -407,16 +427,15 @@ lookupVar loc qn@(QualName qs name) = do then pure t else do (tnames, t') <- instTypeScheme qn loc tparams t - outer_env <- asks termOuterEnv - asStructType loc $ first (const SDim) $ qualifyTypeVars outer_env tnames qs t' + -- TODO - qualify type names, like in the old type checker. + pure t' Just EqualityF -> do - argtype <- - tyVarType NoUniqueness <$> newTyVarWith "t" (TyVarEql (locOf loc)) - pure $ foldFunType [toParam Observe argtype, toParam Observe argtype] $ RetType [] $ Scalar $ Prim Bool + argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarEql (locOf loc)) + pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts let (pts', rt') = instOverloaded argtype pts rt - pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' + pure $ foldFunType (map (second $ const Observe) pts') $ RetType [] $ second (const Nonunique) rt' where instOverloaded argtype pts rt = ( map (maybe argtype (Scalar . Prim)) pts, @@ -427,13 +446,16 @@ bind :: [Ident StructType] -> TermM a -> TermM a -bind idents = localScope (`bindVars` idents) +bind idents m = do + let names = map identName idents + ts <- mapM (asType . unInfo . identType) idents + localScope (`bindVars` zip names ts) m where bindVars = foldl bindVar - bindVar scope (Ident name (Info tp) _) = + bindVar scope (name, t) = scope - { scopeVtable = M.insert name (BoundV [] tp) $ scopeVtable scope + { scopeVtable = M.insert name (BoundV [] t) $ scopeVtable scope } -- All this complexity is just so we can handle un-suffixed numeric @@ -449,24 +471,27 @@ patLitMkType (PatLitPrim v) _ = checkSizeExp' :: ExpBase NoInfo VName -> TermM Exp checkSizeExp' e = do e' <- checkExp e - ctEq (Reason (locOf e)) (expType e') (Scalar (Prim (Signed Int64))) + e_t <- expType e' + ctEq (Reason (locOf e)) e_t (Scalar (Prim (Signed Int64))) pure e' checkPat' :: PatBase NoInfo VName ParamType -> - Inferred ParamType -> + Inferred (TypeBase SComp Diet) -> TermM (Pat ParamType) checkPat' (PatParens p loc) t = PatParens <$> checkPat' p t <*> pure loc checkPat' (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' p t <*> pure loc -checkPat' (Id name NoInfo loc) (Ascribed t) = - pure $ Id name (Info t) loc +checkPat' (Id name NoInfo loc) (Ascribed t) = do + t' <- asStructType t + pure $ Id name (Info t') loc checkPat' (Id name NoInfo loc) NoneInferred = do t <- newType loc "t" Observe pure $ Id name (Info t) loc -checkPat' (Wildcard _ loc) (Ascribed t) = - pure $ Wildcard (Info t) loc +checkPat' (Wildcard _ loc) (Ascribed t) = do + t' <- asStructType t + pure $ Wildcard (Info t') loc checkPat' (Wildcard NoInfo loc) NoneInferred = do t <- newType loc "t" Observe pure $ Wildcard (Info t) loc @@ -477,9 +502,9 @@ checkPat' (TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_t <- replicateM (length ps) (newType loc "t" Observe) - ctEq (Reason (locOf loc)) (toType (Scalar (tupleRecord ps_t))) (toType t) - TuplePat <$> zipWithM checkPat' ps (map Ascribed ps_t) <*> pure loc + ps_tvs <- replicateM (length ps) (newTyVar loc "t") + ctEq (Reason (locOf loc)) (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) t + TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) _ @@ -498,9 +523,8 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs - ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) $ toType t - st <- asStructType loc $ Scalar (Record p_fs') - checkPat' p $ Ascribed $ toParam Observe st + ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t + checkPat' p $ Ascribed $ const Observe <$> Scalar (Record p_fs') where check t_fs = traverse (uncurry checkPat') $ @@ -514,11 +538,11 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do -- Uniqueness kung fu to make the Monoid(mempty) instance give what -- we expect. We should perhaps stop being so implicit. - st' <- asStructType loc $ toType $ resToParam st + st' <- asType $ resToParam st case maybe_outer_t of Ascribed outer_t -> do - ctEq (Reason (locOf loc)) (toType st') (toType outer_t) + ctEq (Reason (locOf loc)) st' outer_t PatAscription <$> checkPat' p (Ascribed st') <*> pure t' @@ -530,7 +554,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do t' <- patLitMkType l loc - ctEq (Reason (locOf loc)) (toType t') (toType t) + ctEq (Reason (locOf loc)) (toType t') t pure $ PatLit l (Info t') loc checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc @@ -547,28 +571,29 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) <+> pretty (length ts) <+> "arguments." ps' <- zipWithM checkPat' ps $ map Ascribed ts - pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc + cs' <- traverse (mapM (asStructType)) cs + pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do p_t <- newType (srclocOf p) "t" Observe checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' - ctEq (Reason (locOf loc)) t' (toType t) - t'' <- asStructType loc t' + ctEq (Reason (locOf loc)) t' t + t'' <- asStructType t' pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps t <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' - t' <- asStructType loc t + t' <- asStructType t pure $ PatConstr n (Info $ toParam Observe t') ps' loc checkPat :: PatBase NoInfo VName (TypeBase Size u) -> - Inferred StructType -> + Inferred Type -> (Pat ParamType -> TermM a) -> TermM a checkPat p t m = - m =<< checkPat' (fmap (toParam Observe) p) (fmap (toParam Observe) t) + m =<< checkPat' (fmap (toParam Observe) p) (fmap (fmap (const Observe)) t) -- | Bind @let@-bound sizes. This is usually followed by 'bindletPat' -- immediately afterwards. @@ -581,7 +606,7 @@ bindSizes sizes m = bind (map sizeWithType sizes) m bindLetPat :: PatBase NoInfo VName (TypeBase Size u) -> - StructType -> + Type -> (Pat ParamType -> TermM a) -> TermM a bindLetPat p t m = do @@ -628,11 +653,15 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> NE.NonEmpty (Shape Size, Type) -> TermM (StructType, NE.NonEmpty AutoMap) +checkApply :: + SrcLoc -> + Maybe (QualName VName) -> + (Shape Size, Type) -> + NE.NonEmpty (Shape Size, Type) -> + TermM (Type, NE.NonEmpty AutoMap) checkApply loc fname (fframe, ftype) args = do ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args - rt' <- asStructType loc rt - pure (rt', argts) + pure (rt, argts) where onArg (i, f_f, f_t) (argframe, argtype) = do (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) @@ -732,18 +761,19 @@ mustHaveFields loc t (f : fs) ve_t = do ctEq (Reason (locOf loc)) t rt checkCase :: - StructType -> + Type -> CaseBase NoInfo VName -> - TermM (CaseBase Info VName, StructType) + TermM (CaseBase Info VName, Type) checkCase mt (CasePat p e loc) = bindLetPat p mt $ \p' -> do e' <- checkExp e - pure (CasePat (fmap toStruct p') e' loc, typeOf e') + e_t <- expType e' + pure (CasePat (fmap toStruct p') e' loc, e_t) checkCases :: - StructType -> + Type -> NE.NonEmpty (CaseBase NoInfo VName) -> - TermM (NE.NonEmpty (CaseBase Info VName), StructType) + TermM (NE.NonEmpty (CaseBase Info VName), Type) checkCases mt rest_cs = case NE.uncons rest_cs of (c, Nothing) -> do @@ -752,7 +782,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq (Reason (locOf c)) (toType c_t) (toType cs_t) + ctEq (Reason (locOf c)) c_t cs_t pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -790,16 +820,18 @@ checkRetDecl :: checkRetDecl _ Nothing = pure Nothing checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (Reason (locOf body)) (expType body) (toType st) + body_t <- expType body + st' <- asType st + ctEq (Reason (locOf body)) body_t st' pure $ Just te' checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- checkExp (Var qn _ loc) = do - t <- lookupVar loc qn + t <- asStructType =<< lookupVar loc qn pure $ Var qn (Info t) loc checkExp (OpSection op _ loc) = do - ftype <- lookupVar loc op + ftype <- asStructType =<< lookupVar loc op pure $ OpSection op (Info ftype) loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg @@ -841,7 +873,9 @@ checkExp (ArrayLit es _ loc) = do et <- newType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e - ctEq (Reason (locOf loc)) (expType e') (toType et) + e_t <- expType e' + et' <- asType et + ctEq (Reason (locOf loc)) e_t et' pure e' let arr_t = arrayOf (Shape [sizeFromInteger (L.genericLength es) loc]) et pure $ ArrayLit es' (Info arr_t) loc @@ -856,7 +890,8 @@ checkExp (RecordLit fs loc) = errIfAlreadySet (baseName name) rloc t <- lift $ lookupVar rloc $ qualName name modify $ M.insert (baseName name) rloc - pure $ RecordFieldImplicit name (Info t) rloc + t' <- lift $ asStructType t + pure $ RecordFieldImplicit name (Info t') rloc errIfAlreadySet f rloc = do maybe_sloc <- gets $ M.lookup f @@ -880,7 +915,8 @@ checkExp (Assert e1 e2 NoInfo loc) = do -- checkExp (Constr name es NoInfo loc) = do es' <- mapM checkExp es - t <- newTypeWithConstr loc "t" NoUniqueness name $ map expType es' + es_ts <- mapM expType es' + t <- newTypeWithConstr loc "t" NoUniqueness name es_ts pure $ Constr name es' (Info t) loc -- checkExp (AppExp (Apply fe args loc) NoInfo) = do @@ -891,13 +927,15 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do args ( \(_, arg) -> do arg' <- checkExp arg - pure (arg', (frameOf arg', expType arg')) + arg_t <- expType arg' + pure (arg', (frameOf arg', arg_t)) ) - (rt, ams) <- checkApply loc fname (frameOf fe', expType fe') argts' + fe_t <- expType fe' + (rt, ams) <- checkApply loc fname (frameOf fe', fe_t) argts' + rt' <- asStructType rt pure $ AppExp (Apply fe' (NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args') loc) $ - Info $ - AppRes rt [] + Info (AppRes rt' []) where fname = case fe of @@ -906,59 +944,71 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op e1' <- checkExp e1 + e1_t <- expType e1' e2' <- checkExp e2 + e2_t <- expType e2' + (rt, ams) <- checkApply loc (Just op) - (mempty, toType ftype) - ((frameOf e1', toType $ typeOf e1') NE.:| [(frameOf e2', toType $ typeOf e2')]) + (mempty, ftype) + ((frameOf e1', e1_t) NE.:| [(frameOf e2', e2_t)]) + rt' <- asStructType rt let (am1 NE.:| [am2]) = ams + ftype' <- asStructType ftype pure $ AppExp - (BinOp (op, oploc) (Info ftype) (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) - (Info (AppRes rt [])) + (BinOp (op, oploc) (Info ftype') (e1', Info (Nothing, am1)) (e2', Info (Nothing, am2)) loc) + (Info (AppRes rt' [])) -- checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e + e_t <- expType e' t2 <- newType loc "t" NoUniqueness - t2' <- asStructType loc t2 - let t1 = typeOf e' - f1 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((f1, toType t1) NE.:| [(mempty, t2)]) + t2' <- asStructType t2 + let f1 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((f1, e_t) NE.:| [(mempty, t2)]) + rt' <- asStructType rt let (am1 NE.:| _) = ams + t1 <- asStructType e_t + optype' <- asStructType optype pure $ OpSectionLeft op - (Info optype) + (Info optype') e' ( Info (Unnamed, toParam Observe t1, Nothing, am1), Info (Unnamed, toParam Observe t2') ) - (Info (RetType [] (rt `setUniqueness` Nonunique)), Info []) + (Info (RetType [] (rt' `setUniqueness` Nonunique)), Info []) loc checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e + e_t <- expType e' t1 <- newType loc "t" NoUniqueness - t1' <- asStructType loc t1 - let t2 = typeOf e' - f2 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, toType optype) ((mempty, t1) NE.:| [(f2, toType t2)]) + t1' <- asStructType t1 + let f2 = frameOf e' + (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((mempty, t1) NE.:| [(f2, e_t)]) + rt' <- asStructType rt let (_ NE.:| [am2]) = ams + t2 <- asStructType e_t + + optype' <- asStructType optype pure $ OpSectionRight op - (Info optype) + (Info optype') e' -- Dummy types. ( Info (Unnamed, toParam Observe t1'), Info (Unnamed, toParam Observe t2, Nothing, am2) ) - (Info $ RetType [] (rt `setUniqueness` Nonunique)) + (Info $ RetType [] (rt' `setUniqueness` Nonunique)) loc -- checkExp (ProjectSection fields NoInfo loc) = do @@ -971,70 +1021,91 @@ checkExp (ProjectSection fields NoInfo loc) = do checkExp (Lambda params body retdecl NoInfo loc) = do bindParams [] params $ \params' -> do body' <- checkExp body + body_t <- expType body' + + body_t' <- asStructType body_t retdecl' <- checkRetDecl body' retdecl - let ret = RetType [] $ toRes Nonunique $ typeOf body' + let ret = RetType [] $ toRes Nonunique body_t' pure $ Lambda params' body' retdecl' (Info ret) loc -- checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e + e_t <- expType e' - bindSizes sizes . incLevel . bindLetPat pat (typeOf e') $ \pat' -> do + bindSizes sizes . incLevel . bindLetPat pat e_t $ \pat' -> do body' <- incLevel $ checkExp body + body_t <- expType body' + + body_t' <- asStructType body_t pure $ AppExp (LetPat sizes (fmap toStruct pat') e' body' loc) - (Info $ AppRes (typeOf body') []) + (Info $ AppRes body_t' []) -- checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) = do (tparams', params', retdecl', rettype, e') <- bindParams tparams params $ \params' -> do e' <- checkExp e - let ret = RetType [] $ toRes Nonunique $ typeOf e' + e_t <- expType e' + let ret = fmap (const Nonunique) e_t retdecl' <- checkRetDecl e' retdecl pure (tparams, params', retdecl', ret, e') - let entry = BoundV tparams' $ funType params' rettype + params'' <- mapM (traverse asType) params' + + let entry = BoundV tparams' $ funType params'' $ RetType [] rettype bindF scope = scope { scopeVtable = M.insert name entry $ scopeVtable scope } body' <- localScope bindF $ checkExp body + body_t <- expType body' + body_t' <- asStructType body_t + rettype' <- asStructType rettype pure $ AppExp ( LetFun name - (tparams', params', retdecl', Info rettype, e') + (tparams', params', retdecl', Info (RetType [] rettype'), e') body' loc ) - (Info $ AppRes (typeOf body') []) + (Info $ AppRes body_t' []) -- checkExp (AppExp (Range start maybe_step end loc) _) = do start' <- require "use in range expression" anyIntType =<< checkExp start let check e = do e' <- checkExp e - ctEq (Reason (locOf e')) (expType start') (expType e') + start_t <- expType start' + e_t <- expType e' + ctEq (Reason (locOf e')) start_t e_t pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end range_t <- newType loc "range" NoUniqueness - ctEq (Reason (locOf start')) (toType range_t) (arrayOfRank 1 (expType start')) + range_t' <- asType range_t + start_t <- expType start' + ctEq (Reason (locOf start')) range_t' (arrayOfRank 1 start_t) pure $ AppExp (Range start' maybe_step' end' loc) $ Info $ AppRes range_t [] -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e kt <- newType loc "kt" NoUniqueness t <- newTypeWithField loc "t" k kt - ctEq (Reason (locOf e')) (expType e') t - kt' <- asStructType loc kt + e_t <- expType e' + ctEq (Reason (locOf e')) e_t t + kt' <- asStructType kt pure $ Project k e' (Info kt') loc -- checkExp (RecordUpdate src fields ve NoInfo loc) = do src' <- checkExp src + src_t <- expType src' ve' <- checkExp ve - mustHaveFields loc (expType src') fields (expType ve') - pure $ RecordUpdate src' fields ve' (Info (typeOf src')) loc + ve_t <- expType ve' + mustHaveFields loc src_t fields ve_t + src_t' <- asStructType src_t + pure $ RecordUpdate src' fields ve' (Info src_t') loc -- checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice @@ -1044,67 +1115,84 @@ checkExp (IndexSection slice NoInfo loc) = do let num_slices = length $ filter isSlice slice ctEq (Reason (locOf loc)) (toType index_arg_t) $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t - index_res_t' <- asStructType loc index_res_t + index_res_t' <- asStructType index_res_t let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' pure $ IndexSection slice' (Info ft) loc -- checkExp (AppExp (Index e slice loc) _) = do e' <- checkExp e + e_t <- expType e' slice' <- checkSlice slice - index_t <- newType loc "index" NoUniqueness + index_tv <- newTyVar loc "index" index_elem_t <- newType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (Reason (locOf loc)) (toType index_t) $ arrayOfRank num_slices index_elem_t - ctEq (Reason (locOf e')) (expType e') $ arrayOfRank (length slice) index_elem_t - pure $ AppExp (Index e' slice' loc) (Info $ AppRes index_t []) + ctEq (Reason (locOf loc)) (tyVarType NoUniqueness index_tv) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf e')) e_t $ arrayOfRank (length slice) index_elem_t + pure $ AppExp (Index e' slice' loc) (Info $ AppRes (tyVarType NoUniqueness index_tv) []) -- checkExp (Update src slice ve loc) = do src' <- checkExp src + src_t <- expType src' slice' <- checkSlice slice ve' <- checkExp ve + ve_t <- expType ve' let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (Reason (locOf src')) (expType src') $ arrayOfRank (length slice) update_elem_t - ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf src')) src_t $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc -- checkExp (AppExp (LetWith dest src slice ve body loc) _) = do src_t <- lookupVar (srclocOf src) $ qualName $ identName src - let src' = src {identType = Info src_t} - dest' = dest {identType = Info src_t} + src_t' <- asStructType src_t + let src' = src {identType = Info src_t'} + dest' = dest {identType = Info src_t'} slice' <- checkSlice slice ve' <- checkExp ve + ve_t <- expType ve' let num_slices = length $ filter isSlice slice update_elem_t <- newType loc "update_elem" NoUniqueness - ctEq (Reason (locOf loc)) (toType src_t) $ arrayOfRank (length slice) update_elem_t - ctEq (Reason (locOf ve')) (expType ve') $ arrayOfRank num_slices update_elem_t + ctEq (Reason (locOf loc)) src_t $ arrayOfRank (length slice) update_elem_t + ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t bind [dest'] $ do body' <- checkExp body - pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes (typeOf body') []) + body_t <- expType body' + body_t' <- asStructType body_t + pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t' []) -- checkExp (AppExp (If e1 e2 e3 loc) _) = do e1' <- checkExp e1 + e1_t <- expType e1' e2' <- checkExp e2 + e2_t <- expType e2' e3' <- checkExp e3 + e3_t <- expType e3' - ctEq (Reason (locOf e1')) (expType e1') (Scalar (Prim Bool)) - ctEq (Reason (locOf loc)) (expType e2') (expType e3') + ctEq (Reason (locOf e1')) e1_t (Scalar (Prim Bool)) + ctEq (Reason (locOf loc)) e2_t e3_t - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes (typeOf e2') []) + e2_t' <- asStructType e2_t + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes e2_t' []) -- checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e - (cs', t) <- checkCases (typeOf e') cs - pure $ AppExp (Match e' cs' loc) (Info $ AppRes t []) + e_t <- expType e' + + (cs', t) <- checkCases e_t cs + t' <- asStructType t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes t' []) -- checkExp (AppExp (Loop _ pat arg form body loc) _) = do arg' <- checkExp arg - bindLetPat pat (typeOf arg') $ \pat' -> do + arg_t <- expType arg' + bindLetPat pat arg_t $ \pat' -> do (form', body') <- case form of For (Ident i _ iloc) bound -> do bound' <- require "loop bound" anyIntType =<< checkExp bound - let i' = Ident i (Info (typeOf bound')) iloc + bound_t <- expType bound' + bound_t' <- asStructType bound_t + let i' = Ident i (Info bound_t') iloc bind [i'] $ do body' <- checkExp body pure (For i' bound', body') @@ -1115,11 +1203,14 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do ForIn elemp arr -> do arr' <- checkExp arr elem_t <- newType elemp "elem" NoUniqueness - ctEq (Reason (locOf arr')) (expType arr') $ arrayOfRank 1 (toType elem_t) - bindLetPat elemp elem_t $ \elemp' -> do + arr_t <- expType arr' + elem_t' <- asType elem_t + ctEq (Reason (locOf arr')) arr_t $ arrayOfRank 1 elem_t' + bindLetPat elemp elem_t' $ \elemp' -> do body' <- checkExp body pure (ForIn (toStruct <$> elemp') arr', body') - ctEq (Reason (locOf loc)) (expType arg') (expType body') + body_t <- expType body' + ctEq (Reason (locOf loc)) arg_t body_t pure $ AppExp (Loop [] pat' arg' form' body' loc) @@ -1128,12 +1219,16 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do checkExp (Ascript e te loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (Reason (locOf e')) (expType e') (toType st) + e_t <- expType e' + st' <- asType st + ctEq (Reason (locOf e')) e_t st' pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te - ctEq (Reason (locOf e')) (expType e') (toType st) + e_t <- expType e' + st' <- asType st + ctEq (Reason (locOf e')) e_t st' pure $ Coerce e' te' (Info (toStruct st)) loc doDefault :: @@ -1172,7 +1267,7 @@ doDefaults tyvars_at_toplevel substs = do pure $ M.map (substTyVars (`M.lookup` substs')) substs' generalise :: - StructType -> [VName] -> Solution -> ([TypeParam], [VName]) + TypeBase () NoUniqueness -> [VName] -> Solution -> ([TypeParam], [VName]) generalise fun_t unconstrained solution = -- Candidates for let-generalisation are those type variables that -- are used in fun_t. @@ -1190,7 +1285,7 @@ generalise fun_t unconstrained solution = generaliseAndDefaults :: [VName] -> Solution -> - StructType -> + TypeBase () NoUniqueness -> TermM ([TypeParam], M.Map VName (TypeBase () NoUniqueness)) generaliseAndDefaults unconstrained solution t = do let (generalised, unconstrained') = @@ -1225,8 +1320,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do pure (params', body', retdecl') cts <- gets termConstraints - tyvars <- gets termTyVars + artificial <- gets termArtificial debugTraceM 3 $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" @@ -1237,14 +1332,20 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do "## body:", prettyString body', "## tyvars:", - unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars + unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars, + "## artificial:", + unlines $ map (\(v, t) -> prettyNameString v <> " => " <> prettyString t) (M.toList artificial) ] - onRankSolution retdecl' =<< rankAnalysis1 loc cts tyvars params' body' + onRankSolution retdecl' + =<< rankAnalysis1 loc cts tyvars artificial params' body' where - onRankSolution retdecl' ((cts', tyvars'), params', body'') = do + onRankSolution retdecl' ((cts', tyvars'), artificial, params', body'') = do solution <- - bitraverse pure (onTySolution params' body'') $ solve cts' tyvars' + bitraverse + pure + (fmap (second (onArtificial artificial)) . onTySolution params' body'') + $ solve cts' tyvars' debugTraceM 3 $ unlines [ "## constraints:", @@ -1259,12 +1360,16 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do pure (solution, params', retdecl', body'') onTySolution params' body' (unconstrained, solution) = do + body_t <- expType body' let fun_t = foldFunType - (map patternType params') - (RetType [] $ toRes Nonunique (typeOf body')) + (map (first (const ()) . patternType) params') + (RetType [] $ bimap (const ()) (const Nonunique) body_t) generaliseAndDefaults unconstrained solution fun_t + onArtificial artificial solution = + M.map (substTyVars (`M.lookup` solution) . first (const ())) artificial <> solution + checkSingleExp :: ExpBase NoInfo VName -> TypeM (Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), Exp) @@ -1272,11 +1377,13 @@ checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars - ((cts', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars [] e' + artificial <- gets termArtificial + ((cts', tyvars'), _, _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' case solve cts' tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do - x <- generaliseAndDefaults unconstrained solution $ typeOf e'' + e_t <- expType e'' + x <- generaliseAndDefaults unconstrained solution $ first (const ()) e_t pure (Right x, e'') -- | Type-check a single size expression in isolation. This expression may @@ -1288,8 +1395,9 @@ checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints tyvars <- gets termTyVars + artificial <- gets termArtificial - (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars [] e' + (cts_tyvars', _, _, es') <- L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- forM cts_tyvars' $ From 2fbe6c22b6565074cf9e36dc1de0c5c0b5f1e13f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:30:20 +0100 Subject: [PATCH 194/296] Better to do this in same pass. --- src/Language/Futhark/TypeChecker/Rank.hs | 53 ++++++++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 15 +++--- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index fba519b544..c74fad677c 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -218,13 +218,21 @@ solveRankILP loc prog = do : map prettyString (M.toList r) pure rs -rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m (([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp) +rankAnalysis1 :: + (MonadTypeChecker m) => + SrcLoc -> + [Ct] -> + TyVars -> + M.Map TyVar Type -> + [Pat ParamType] -> + Exp -> + m (([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp) rankAnalysis1 loc cs tyVars artificial params body = do solutions <- rankAnalysis loc cs tyVars artificial params body case solutions of [sol] -> pure sol sols -> do - let (_, _, _, bodies') = L.unzip4 sols + let (_, _, bodies') = unzip3 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -232,15 +240,23 @@ rankAnalysis1 loc cs tyVars artificial params body = do ] ++ map pretty bodies' -rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> [Ct] -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> Exp -> m [(([Ct], TyVars), M.Map TyVar Type, [Pat ParamType], Exp)] -rankAnalysis _ [] tyVars artificial params body = pure [(([], tyVars), artificial, params, body)] +rankAnalysis :: + (MonadTypeChecker m) => + SrcLoc -> + [Ct] -> + TyVars -> + M.Map TyVar Type -> + [Pat ParamType] -> + Exp -> + m [(([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp)] +rankAnalysis _ [] tyVars artificial params body = + pure [(([], artificial, tyVars), params, body)] rankAnalysis loc cs tyVars artificial params body = do rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) - cts_tyvars' <- mapM (substRankInfo cs tyVars) rank_maps + cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps let bodys = map (`updAM` body) rank_maps params' = map ((`map` params) . updAMPat) rank_maps - artificial' <- mapM (substRankInfoArtificial tyVars artificial) rank_maps - pure $ L.zip4 cts_tyvars' artificial' params' bodys + pure $ zip3 cts_tyvars' params' bodys where cs' = foldMap (splitFuncs . distribute) cs splitFuncs @@ -255,23 +271,22 @@ rankAnalysis loc cs tyVars artificial params body = do t2r' = t2r `setUniqueness` NoUniqueness splitFuncs c = [c] -substRankInfo :: (MonadTypeChecker m) => [Ct] -> TyVars -> Map VName Int -> m ([Ct], TyVars) -substRankInfo cs tyVars rankmap = do - (cs', new_cs, new_tyVars) <- +substRankInfo :: + (MonadTypeChecker m) => + [Ct] -> + M.Map VName Type -> + TyVars -> + Map VName Int -> + m ([Ct], M.Map VName Type, TyVars) +substRankInfo cs artificial tyVars rankmap = do + ((cs', artificial'), new_cs, new_tyVars) <- runSubstT tyVars rankmap $ - substRanks $ - filter (not . isCtAM) cs - pure (cs' <> new_cs, new_tyVars <> tyVars) + (,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial + pure (cs' <> new_cs, artificial', new_tyVars <> tyVars) where isCtAM (CtAM {}) = True isCtAM _ = False -substRankInfoArtificial :: (MonadTypeChecker m) => TyVars -> M.Map VName Type -> Map VName Int -> m (M.Map VName Type) -substRankInfoArtificial tyvars artificial rankmap = do - (artificial', _, _) <- - runSubstT tyvars rankmap $ traverse substRanks artificial - pure artificial' - runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) runSubstT tyVars rankmap (SubstT m) = do let env = diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 6839cd3d29..371f139f99 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -524,7 +524,7 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) | otherwise = do p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t - checkPat' p $ Ascribed $ const Observe <$> Scalar (Record p_fs') + checkPat' p $ Ascribed $ Observe <$ Scalar (Record p_fs') where check t_fs = traverse (uncurry checkPat') $ @@ -571,7 +571,7 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) <+> pretty (length ts) <+> "arguments." ps' <- zipWithM checkPat' ps $ map Ascribed ts - cs' <- traverse (mapM (asStructType)) cs + cs' <- traverse (mapM asStructType) cs pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do @@ -1340,7 +1340,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do onRankSolution retdecl' =<< rankAnalysis1 loc cts tyvars artificial params' body' where - onRankSolution retdecl' ((cts', tyvars'), artificial, params', body'') = do + onRankSolution retdecl' ((cts', artificial, tyvars'), params', body'') = do solution <- bitraverse pure @@ -1378,7 +1378,8 @@ checkSingleExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars artificial <- gets termArtificial - ((cts', tyvars'), _, _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' + ((cts', artificial', tyvars'), _, e'') <- + rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' case solve cts' tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do @@ -1397,11 +1398,11 @@ checkSizeExp e = runTermM $ do tyvars <- gets termTyVars artificial <- gets termArtificial - (cts_tyvars', _, _, es') <- L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' + (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- - forM cts_tyvars' $ - bitraverse pure (traverse (doDefaults mempty)) . uncurry solve + forM cts_tyvars' $ \(cts', artificial', tyvars') -> + bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' case (solutions, es') of ([solution], [e'']) -> From f16ee6aa371192a50adce27fe5f795e37b2131dc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:37:08 +0100 Subject: [PATCH 195/296] More cleanup. --- src/Language/Futhark/TypeChecker/Terms2.hs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 371f139f99..5917860e60 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1014,8 +1014,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do checkExp (ProjectSection fields NoInfo loc) = do a <- newType loc "a" NoUniqueness b <- newType loc "b" NoUniqueness - mustHaveFields loc (toType a) fields (toType b) - let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique + mustHaveFields loc a fields b + ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc -- checkExp (Lambda params body retdecl NoInfo loc) = do @@ -1047,9 +1047,8 @@ checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) bindParams tparams params $ \params' -> do e' <- checkExp e e_t <- expType e' - let ret = fmap (const Nonunique) e_t retdecl' <- checkRetDecl e' retdecl - pure (tparams, params', retdecl', ret, e') + pure (tparams, params', retdecl', fmap (const Nonunique) e_t, e') params'' <- mapM (traverse asType) params' @@ -1113,10 +1112,9 @@ checkExp (IndexSection slice NoInfo loc) = do index_elem_t <- newType loc "index_elem" NoUniqueness index_res_t <- newType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice - ctEq (Reason (locOf loc)) (toType index_arg_t) $ arrayOfRank num_slices index_elem_t + ctEq (Reason (locOf loc)) index_arg_t $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t - index_res_t' <- asStructType index_res_t - let ft = Scalar $ Arrow mempty Unnamed Observe index_arg_t $ toResRet Nonunique $ RetType [] index_res_t' + ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe index_arg_t $ second (const Nonunique) $ RetType [] index_res_t pure $ IndexSection slice' (Info ft) loc -- checkExp (AppExp (Index e slice loc) _) = do From 74be1732e22138842e86847a9a8f5b527e65318a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 12 Mar 2024 17:38:23 +0100 Subject: [PATCH 196/296] Consistency. --- src/Language/Futhark/TypeChecker/Terms2.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 5917860e60..aceabbbe22 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -888,10 +888,9 @@ checkExp (RecordLit fs loc) = RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc checkField (RecordFieldImplicit name NoInfo rloc) = do errIfAlreadySet (baseName name) rloc - t <- lift $ lookupVar rloc $ qualName name + t <- lift $ asStructType =<< lookupVar rloc (qualName name) modify $ M.insert (baseName name) rloc - t' <- lift $ asStructType t - pure $ RecordFieldImplicit name (Info t') rloc + pure $ RecordFieldImplicit name (Info t) rloc errIfAlreadySet f rloc = do maybe_sloc <- gets $ M.lookup f From 64f6a86174e995e58552a1e01ef6efed73b5d276 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 19:41:27 -0700 Subject: [PATCH 197/296] Fix crashing when LP objective is a constant. This fix is jank(ish). --- src/Futhark/Solve/GLPK.hs | 10 +++++++--- src/Futhark/Solve/LP.hs | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index b2d340d683..b4b0b4602b 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -48,8 +48,12 @@ glpk lp = do res glpk' :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) -glpk' lp = do - (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp - pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres +glpk' lp + | F.isConstant (F.objective lp) -- FIXME + = + pure $ pure (0, M.fromList $ map (,0) $ S.toList $ F.vars lp) + | otherwise = do + (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp + pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres where opts = mipDefaults {msgLev = MsgAll} diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index f1b7d18939..47804b738b 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -28,6 +28,7 @@ module Futhark.Solve.LP (~<=~), (~>=~), rowEchelonLPE, + isConstant, ) where @@ -109,6 +110,9 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where ) $ Map.toList m +isConstant :: (Ord v) => LSum v a -> Bool +isConstant (LSum m) = Map.keysSet m `S.isSubsetOf` S.singleton Nothing + instance Functor (LSum v) where fmap f (LSum m) = LSum $ fmap f m From 5666c2b20ade746507e5bdbb83e5e333019d37b7 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 19:58:39 -0700 Subject: [PATCH 198/296] Fix tuples/records in rank analysis. --- src/Language/Futhark/TypeChecker/Rank.hs | 94 ++++++++++++++++-------- 1 file changed, 64 insertions(+), 30 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c74fad677c..2f2bbcda6b 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -57,23 +57,63 @@ instance Rank Type where rank (Scalar t) = rank t rank (Array _ shape t) = rank shape ~+~ rank t -class Distribute a where - distribute :: a -> a - -instance Distribute (TypeBase dim u) where - distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ - Arrow - u - Unnamed - mempty - (arrayOf s ta) - (RetType rd $ distribute (arrayOfWithAliases Nonunique s tr)) - distribute t = t - -instance Distribute Ct where - distribute (CtEq r t1 t2) = CtEq r (distribute t1) (distribute t2) - distribute c = c +distribAndSplitArrows :: Ct -> [Ct] +distribAndSplitArrows (CtEq r t1 t2) = + splitArrows $ CtEq r (distribute t1) (distribute t2) + where + distribute :: TypeBase dim as -> TypeBase dim as + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute $ arrayOfWithAliases Nonunique s tr) + distribute t = t + + splitArrows + ( CtEq + reason + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) + ) = + splitArrows (CtEq reason t1a t2a) ++ splitArrows (CtEq reason t1r' t2r') + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness + splitArrows c = [c] +distribAndSplitArrows ct = [ct] + +distribAndSplitCnstrs :: Ct -> [Ct] +distribAndSplitCnstrs ct@(CtEq r t1 t2) = + ct : splitCnstrs (CtEq r (distribute1 t1) (distribute1 t2)) + where + distribute1 :: TypeBase dim as -> TypeBase dim as + distribute1 (Array u s (Record ts1)) = + Scalar $ Record $ fmap (arrayOfWithAliases u s) ts1 + distribute1 t = t + + splitCnstrs (CtEq reason (Scalar (Record ts1)) (Scalar (Record ts2))) = + concat $ zipWith (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems ts1) (M.elems ts2) + splitCnstrs c = [] +distribAndSplitCnstrs ct = [ct] + +distributeOverCnstrs :: Ct -> [Ct] +distributeOverCnstrs ct@(CtEq r t1 t2) = + [ct, CtEq r t1' t2'] + where + -- case (t1', t2') of + -- (Nothing, Nothing) -> [ct] + -- _ -> [ct, CtEq r (fromMaybe t1 t1') (fromMaybe t2 t2')] + + distribute :: TypeBase dim as -> TypeBase dim as + distribute (Array u s (Record ts1)) = + Scalar $ Record $ fmap (distribute . arrayOfWithAliases u s) ts1 + distribute t = t + t1' = distribute t1 + t2' = distribute t2 +distributeOverCnstrs c = [c] data RankState = RankState { rankBinVars :: Map VName VName, @@ -258,18 +298,9 @@ rankAnalysis loc cs tyVars artificial params body = do params' = map ((`map` params) . updAMPat) rank_maps pure $ zip3 cts_tyvars' params' bodys where - cs' = foldMap (splitFuncs . distribute) cs - splitFuncs - ( CtEq - reason - (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) - (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) - ) = - splitFuncs (CtEq reason t1a t2a) ++ splitFuncs (CtEq reason t1r' t2r') - where - t1r' = t1r `setUniqueness` NoUniqueness - t2r' = t2r `setUniqueness` NoUniqueness - splitFuncs c = [c] + cs' = + foldMap distribAndSplitCnstrs $ + foldMap distribAndSplitArrows cs substRankInfo :: (MonadTypeChecker m) => @@ -331,7 +362,7 @@ newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar newTyVar t = do t' <- lift $ newTypeName (baseName t) shape <- rankToShape t - loc <- (locOf . snd . fromJust . (M.!? t)) <$> asks envTyVars + loc <- asks ((locOf . snd . fromJust . (M.!? t)) . envTyVars) modify $ \s -> s { substNewVars = M.insert t t' $ substNewVars s, @@ -399,6 +430,9 @@ instance SubstRanks (TypeBase SComp u) where shape' <- substRanks shape t' <- substRanks $ Scalar t pure $ arrayOfWithAliases u shape' t' + substRanks (Scalar (Record fs)) = do + fs' <- mapM substRanks fs + pure $ Scalar $ Record fs' substRanks t = pure t instance SubstRanks Ct where From 7039dcc5603a1e820ddbcefda8cbc2074457eea2 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 21:05:59 -0700 Subject: [PATCH 199/296] Delete this. --- src/Language/Futhark/TypeChecker/Rank.hs | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 2f2bbcda6b..939b1dfb9f 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -99,22 +99,6 @@ distribAndSplitCnstrs ct@(CtEq r t1 t2) = splitCnstrs c = [] distribAndSplitCnstrs ct = [ct] -distributeOverCnstrs :: Ct -> [Ct] -distributeOverCnstrs ct@(CtEq r t1 t2) = - [ct, CtEq r t1' t2'] - where - -- case (t1', t2') of - -- (Nothing, Nothing) -> [ct] - -- _ -> [ct, CtEq r (fromMaybe t1 t1') (fromMaybe t2 t2')] - - distribute :: TypeBase dim as -> TypeBase dim as - distribute (Array u s (Record ts1)) = - Scalar $ Record $ fmap (distribute . arrayOfWithAliases u s) ts1 - distribute t = t - t1' = distribute t1 - t2' = distribute t2 -distributeOverCnstrs c = [c] - data RankState = RankState { rankBinVars :: Map VName VName, rankCounter :: !Int, From e5b4f10018589e2d47dce1fb5d3e3b918e950140 Mon Sep 17 00:00:00 2001 From: zfnmxt Date: Tue, 12 Mar 2024 21:45:45 -0700 Subject: [PATCH 200/296] Add sum type support. --- src/Language/Futhark/TypeChecker/Rank.hs | 41 +++++++++++++++++++----- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 939b1dfb9f..42c440dd26 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -92,11 +92,16 @@ distribAndSplitCnstrs ct@(CtEq r t1 t2) = distribute1 :: TypeBase dim as -> TypeBase dim as distribute1 (Array u s (Record ts1)) = Scalar $ Record $ fmap (arrayOfWithAliases u s) ts1 + distribute1 (Array u s (Sum cs)) = + Scalar $ Sum $ (fmap . fmap) (arrayOfWithAliases u s) cs distribute1 t = t + -- FIXME. Should check for key set equality here. splitCnstrs (CtEq reason (Scalar (Record ts1)) (Scalar (Record ts2))) = concat $ zipWith (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems ts1) (M.elems ts2) - splitCnstrs c = [] + splitCnstrs (CtEq reason (Scalar (Sum cs1)) (Scalar (Sum cs2))) = + concat $ concat $ (zipWith . zipWith) (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems cs1) (M.elems cs2) + splitCnstrs _ = [] distribAndSplitCnstrs ct = [ct] data RankState = RankState @@ -276,6 +281,14 @@ rankAnalysis :: rankAnalysis _ [] tyVars artificial params body = pure [(([], artificial, tyVars), params, body)] rankAnalysis loc cs tyVars artificial params body = do + debugTraceM 3 $ + unlines $ + [ "##rankAnalysis", + "cs:", + unlines $ map prettyString cs, + "cs':", + unlines $ map prettyString cs' + ] rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps let bodys = map (`updAM` body) rank_maps @@ -294,10 +307,10 @@ substRankInfo :: Map VName Int -> m ([Ct], M.Map VName Type, TyVars) substRankInfo cs artificial tyVars rankmap = do - ((cs', artificial'), new_cs, new_tyVars) <- + ((cs', artificial', tyVars'), new_cs, new_tyVars) <- runSubstT tyVars rankmap $ - (,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial - pure (cs' <> new_cs, artificial', new_tyVars <> tyVars) + (,,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial <*> traverse substRanks tyVars + pure (cs' <> new_cs, artificial', new_tyVars <> tyVars') where isCtAM (CtAM {}) = True isCtAM _ = False @@ -409,20 +422,32 @@ instance SubstRanks (TypeBase SComp u) where ta' <- substRanks ta tr' <- substRanks tr pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) - substRanks (Scalar (Record fs)) = Scalar . Record <$> traverse substRanks fs + substRanks (Scalar (Record fs)) = + Scalar . Record <$> traverse substRanks fs + substRanks (Scalar (Sum cs)) = + Scalar . Sum <$> (traverse . traverse) substRanks cs substRanks (Array u shape t) = do shape' <- substRanks shape t' <- substRanks $ Scalar t pure $ arrayOfWithAliases u shape' t' - substRanks (Scalar (Record fs)) = do - fs' <- mapM substRanks fs - pure $ Scalar $ Record fs' substRanks t = pure t instance SubstRanks Ct where substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2 substRanks _ = error "" +instance SubstRanks TyVarInfo where + substRanks tv@TyVarFree {} = pure tv + substRanks tv@TyVarPrim {} = pure tv + substRanks (TyVarRecord loc fs) = + TyVarRecord loc <$> traverse substRanks fs + substRanks (TyVarSum loc cs) = + TyVarSum loc <$> (traverse . traverse) substRanks cs + substRanks tv@TyVarEql {} = pure tv + +instance SubstRanks (Int, TyVarInfo) where + substRanks (lvl, tv) = (lvl,) <$> substRanks tv + updAM :: Map VName Int -> Exp -> Exp updAM rank_map e = case e of From 62ddb1d194a6f0a54c2c7a633b98740c63d9cfd9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Mar 2024 12:47:47 +0100 Subject: [PATCH 201/296] Respect return type annotations. --- src/Language/Futhark/TypeChecker/Terms2.hs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index aceabbbe22..cfa88069ca 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -816,14 +816,14 @@ instance Pretty (Unmatched (Pat StructType)) where checkRetDecl :: Exp -> Maybe (TypeExp (ExpBase NoInfo VName) VName) -> - TermM (Maybe (TypeExp Exp VName)) -checkRetDecl _ Nothing = pure Nothing + TermM (Type, Maybe (TypeExp Exp VName)) +checkRetDecl body Nothing = (,Nothing) <$> expType body checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te body_t <- expType body st' <- asType st ctEq (Reason (locOf body)) body_t st' - pure $ Just te' + pure (second (const NoUniqueness) st', Just te') checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- @@ -1020,10 +1020,9 @@ checkExp (ProjectSection fields NoInfo loc) = do checkExp (Lambda params body retdecl NoInfo loc) = do bindParams [] params $ \params' -> do body' <- checkExp body - body_t <- expType body' + (body_t, retdecl') <- checkRetDecl body' retdecl body_t' <- asStructType body_t - retdecl' <- checkRetDecl body' retdecl let ret = RetType [] $ toRes Nonunique body_t' pure $ Lambda params' body' retdecl' (Info ret) loc -- @@ -1045,8 +1044,7 @@ checkExp (AppExp (LetFun name (tparams, params, retdecl, NoInfo, e) body loc) _) (tparams', params', retdecl', rettype, e') <- bindParams tparams params $ \params' -> do e' <- checkExp e - e_t <- expType e' - retdecl' <- checkRetDecl e' retdecl + (e_t, retdecl') <- checkRetDecl e' retdecl pure (tparams, params', retdecl', fmap (const Nonunique) e_t, e') params'' <- mapM (traverse asType) params' @@ -1313,7 +1311,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do (params', body', retdecl') <- bindParams tparams params $ \params' -> do body' <- checkExp body - retdecl' <- checkRetDecl body' retdecl + (_, retdecl') <- checkRetDecl body' retdecl pure (params', body', retdecl') cts <- gets termConstraints From fc5f05dc2f33b6beb8934c8cdfd684e5654b7787 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Mar 2024 15:58:23 +0100 Subject: [PATCH 202/296] Lovely code. --- tests/automap/mri-q.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut index f53b5df7a6..270e18195a 100644 --- a/tests/automap/mri-q.fut +++ b/tests/automap/mri-q.fut @@ -38,4 +38,4 @@ entry main [numK][numX] (phiR: [numK]f32) (phiI: [numK]f32) = let (qr, qi) = main_orig kx ky kz x y z phiR phiI let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI - in and (map2 (==) qr qr_am && map2 (==) qi qi_am) + in and (qr == qr_am && qi == qi_am) From 6f3e32d4ccdb77c84bddfee32f42b743f523835c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Mar 2024 11:03:33 +0100 Subject: [PATCH 203/296] Update type annotation here. --- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 2a648ea8ab..e365cd97fc 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -176,8 +176,9 @@ checkPat' sizes (PatAscription p t loc) maybe_outer_t = do <$> checkPat' sizes p (Ascribed (resToParam st)) <*> pure t' <*> pure loc -checkPat' _ (PatLit l info loc) _ = - pure $ PatLit l info loc +checkPat' _ (PatLit l (Info t) loc) _ = do + t' <- replaceTyVars loc t + pure $ PatLit l (Info t') loc checkPat' sizes (PatConstr n info ps loc) NoneInferred = do ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps pure $ PatConstr n info ps' loc From bf64e471a8e5cbe36138c8deacb77d9679cfa70d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Mar 2024 11:53:24 +0100 Subject: [PATCH 204/296] Handle equality case too. --- src/Language/Futhark/TypeChecker/Constraints.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 4a7d30ae98..8c6c3ca3ef 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -197,7 +197,9 @@ linkTyVar reason v t = do modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} tyvars' <- case (M.lookup v tyvars, M.lookup t tyvars) of - (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl (TyVarFree _))) -> + (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree {})) -> + pure $ M.insert t (TyVarUnsol lvl info) tyvars + (Just (TyVarUnsol _ info@TyVarPrim {}), Just (TyVarUnsol lvl TyVarEql {})) -> pure $ M.insert t (TyVarUnsol lvl info) tyvars -- TODO: handle more cases. _ -> pure tyvars From 5941f88bf68ea4aa0f2cd15c5729ad116881689b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Mar 2024 17:05:16 +0100 Subject: [PATCH 205/296] Do not impose inferred type on polymorphic functions. This is necessary to properly handle polymorphic higher order functions that are passed functions with existential return sizes. --- .../Futhark/TypeChecker/Terms/Monad.hs | 27 ++++++++++--------- src/Language/Futhark/TypeChecker/Unify.hs | 4 +-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 42c5d53ab2..b8371cbfe8 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -440,31 +440,32 @@ instTyVars loc names orig_t1 orig_t2 = do evalStateT (f orig_t1 orig_t2) mempty --- | Instantiate a type scheme with fresh size variables for its size --- parameters. Replaces type parameters with their known --- instantiations. Returns the names of the fresh size variables and --- the instantiated type. +-- | Instantiate a type scheme with fresh variables for its size and +-- type parameters. Returns the names of the fresh size and type +-- variables and the instantiated type. instTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> StructType -> - TypeBase () NoUniqueness -> TermTypeM ([VName], StructType) -instTypeScheme qn loc tparams scheme_t inferred = do - (names, substs) <- fmap (unzip . catMaybes) . forM tparams $ \tparam -> do +instTypeScheme qn loc tparams scheme_t = do + (names, substs) <- fmap unzip . forM tparams $ \tparam -> do case tparam of - TypeParamType {} -> pure Nothing + TypeParamType l v _ -> do + i <- incCounter + v' <- newID $ mkTypeVarName (baseName v) i + constrain v' . NoConstraint l . mkUsage loc . docText $ + "instantiated type parameter of " <> dquotes (pretty qn) + pure (v', (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v') [])) TypeParamDim v _ -> do i <- incCounter v' <- newID $ mkTypeVarName (baseName v) i constrain v' . Size Nothing . mkUsage loc . docText $ "instantiated size parameter of " <> dquotes (pretty qn) - pure $ Just (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) + pure (v', (v, ExpSubst $ sizeFromName (qualName v') loc)) - let tp_names = map typeParamName $ filter isTypeParam tparams - t' <- instTyVars loc tp_names inferred $ applySubst (`lookup` substs) scheme_t - pure (names, t') + pure (names, applySubst (`lookup` substs) scheme_t) lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) @@ -541,7 +542,7 @@ lookupVar loc qn@(QualName qs name) inst_t = do if null tparams && null qs then pure bound_t else do - (tnames, t) <- instTypeScheme qn loc tparams bound_t $ first (const ()) inst_t + (tnames, t) <- instTypeScheme qn loc tparams bound_t outer_env <- asks termOuterEnv pure $ qualifyTypeVars outer_env tnames qs t Just EqualityF -> diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 4493b02b2d..259f62ed0e 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -230,7 +230,7 @@ prettySource ctx loc RigidCoerce = <+> pretty (locStrRel ctx loc) <> "." prettySource _ _ RigidUnify = - "is an artificial size invented during unification of functions with anonymous sizes." + textwrap "is an artificial size invented during unification of functions with anonymous sizes." prettySource ctx loc (RigidCond t1 t2) = "is unknown due to conditional expression at " <> pretty (locStrRel ctx loc) @@ -514,7 +514,7 @@ unifySizes usage bcs bound nonrigid e1 (Var v2 _ _) not (anyBound bound e1) || (qualLeaf v2 `elem` bound) = linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 unifySizes usage bcs _ _ e1 e2 = do - notes <- (<>) <$> dimNotes usage e2 <*> dimNotes usage e2 + notes <- (<>) <$> dimNotes usage e1 <*> dimNotes usage e2 unifyError usage notes bcs $ "Sizes" <+> dquotes (pretty e1) From 64637d06eadd31ff898eeb428e78a2385d76a8a8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 18 Mar 2024 09:59:22 +0100 Subject: [PATCH 206/296] Fix implicit record fields. --- src/Language/Futhark/TypeChecker/Terms.hs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index b3afad3954..081e9c0702 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -343,6 +343,9 @@ unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp checkExp :: Exp -> TermTypeM Exp +checkExp (Var qn (Info t) loc) = do + t' <- lookupVar loc qn t + pure $ Var qn (Info t') loc checkExp (Literal val loc) = pure $ Literal val loc checkExp (Hole (Info t) loc) = do @@ -363,8 +366,9 @@ checkExp (RecordLit fs loc) = where checkField (RecordFieldExplicit f e rloc) = RecordFieldExplicit f <$> checkExp e <*> pure rloc - checkField (RecordFieldImplicit name (Info t) rloc) = - RecordFieldImplicit name <$> (Info <$> replaceTyVars rloc t) <*> pure rloc + checkField (RecordFieldImplicit name (Info t) rloc) = do + t' <- lookupVar rloc (qualName name) t + pure $ RecordFieldImplicit name (Info t') rloc checkExp (ArrayLit all_es _ loc) = -- Construct the result type and unify all elements with it. We -- only create a type variable for empty arrays; otherwise we use @@ -540,9 +544,6 @@ checkExp (QualParens (modname, modnameloc) e loc) = do ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." -checkExp (Var qn (Info t) loc) = do - t' <- lookupVar loc qn t - pure $ Var qn (Info t') loc checkExp (Negate arg loc) = do arg' <- checkExp arg pure $ Negate arg' loc From 616a6122a69ab2940817bf48009840c580f00bc1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 18 Mar 2024 11:42:19 +0100 Subject: [PATCH 207/296] Fix type checking of project sections. --- src/Language/Futhark/Prop.hs | 9 +++++++++ src/Language/Futhark/TypeChecker/Terms.hs | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 0507139fdc..4a339ee124 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -57,6 +57,7 @@ module Language.Futhark.Prop foldFunType, typeVars, isAccType, + recordField, -- * Operations on types peelArray, @@ -251,6 +252,14 @@ diet (Array d _ _) = d diet (Scalar (TypeVar d _ _)) = d diet (Scalar (Sum cs)) = foldl max Observe $ foldMap (map diet) cs +-- | Look up this record field if it exists. +recordField :: [Name] -> TypeBase dim u -> Maybe (TypeBase dim u) +recordField [] t = Just t +recordField (f : fs) (Scalar (Record fts)) + | Just ft <- M.lookup f fts = + recordField fs ft +recordField _ _ = Nothing + -- | Convert any type to one that has rank information, no alias -- information, and no embedded names. toStructural :: diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 081e9c0702..75dd2947ea 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -770,6 +770,11 @@ checkExp (OpSectionRight op (Info op_t) e (_, Info (_, _, _, am)) _ loc) = do "Operator section with invalid operator of type" <+> pretty ftype checkExp (ProjectSection fields (Info t) loc) = do t' <- replaceTyVars loc t + case t' of + Scalar (Arrow _ _ _ t'' (RetType _ rt)) + | Just ft <- recordField fields t'' -> + unify (mkUsage loc "result of projection") ft $ toStruct rt + _ -> error $ "checkExp ProjectSection: " <> show t' pure $ ProjectSection fields (Info t') loc checkExp (IndexSection slice _ loc) = do slice' <- checkSlice slice From ac0472881efd8a3251a6b859b767337f17e60d70 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 27 Mar 2024 13:31:34 +0100 Subject: [PATCH 208/296] Rework type constraint solving. --- .../Futhark/TypeChecker/Constraints.hs | 291 ++++++++++++++---- 1 file changed, 232 insertions(+), 59 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 8c6c3ca3ef..e6759a9ff7 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -17,6 +17,7 @@ where import Control.Monad.Except import Control.Monad.State import Data.Bifunctor +import Data.List qualified as L import Data.Loc import Data.Map qualified as M import Data.Maybe @@ -139,14 +140,14 @@ substTyVar m v = Just (TyVarUnsol {}) -> Nothing Nothing -> Nothing -lookupTyVar :: TyVar -> SolveM (Maybe Type) +lookupTyVar :: TyVar -> SolveM (Int, Either TyVarInfo Type) lookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of Nothing -> error $ "Unknown tyvar: " <> prettyNameString v - Just (TyVarSol _ t) -> pure $ Just t + Just (TyVarSol lvl t) -> pure (lvl, Right t) Just (TyVarLink v') -> f v' - Just (TyVarUnsol {}) -> pure Nothing + Just (TyVarUnsol lvl info) -> pure (lvl, Left info) f orig -- | A solution maps a type variable to its substitution. This @@ -185,25 +186,229 @@ occursCheck reason v tp = do <+> pretty tp <> "." +unifySharedConstructors :: + Reason -> + M.Map Name [Type] -> + M.Map Name [Type] -> + SolveM () +unifySharedConstructors reason cs1 cs2 = + forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> + if length ts1 == length ts2 + then zipWithM (solveEq reason) ts1 ts2 + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructor" + indent 2 (pretty (Sum (M.singleton c ts1))) + "with type of constructor" + indent 2 (pretty (Sum (M.singleton c ts2))) + "because they differ in arity." + +unifySharedFields :: + Reason -> + M.Map Name Type -> + M.Map Name Type -> + SolveM () +unifySharedFields reason fs1 fs2 = + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_f, (ts1, ts2)) -> + solveEq reason ts1 ts2 + +mustSupportEql :: Reason -> Type -> SolveM () +mustSupportEql reason t = pure () + +-- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v lvl t = do occursCheck reason v t - modify $ \s -> s {solverTyVars = M.insert v (TyVarSol lvl t) $ solverTyVars s} - + v_info <- gets $ M.lookup v . solverTyVars + case (v_info, t) of + (Just (TyVarUnsol _ TyVarFree {}), _) -> + pure () + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + _ + ) -> + if t `elem` map (Scalar . Prim) v_pts + then pure () + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with" + indent 2 (pretty t) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Scalar (Sum cs2) + ) -> + if all (`elem` M.keys cs2) (M.keys cs1) + then unifySharedConstructors reason cs1 cs2 + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Sum cs2)) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + _ + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty t) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Scalar (Record fs2) + ) -> + if all (`elem` M.keys fs2) (M.keys fs1) + then unifySharedFields reason fs1 fs2 + else + throwError . TypeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with record type" + indent 2 (pretty (Record fs2)) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + _ + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty t) + (Just (TyVarUnsol _ (TyVarEql _)), _) -> + mustSupportEql reason t + -- + -- Internal error cases + (Just TyVarSol {}, _) -> + error $ "Type variable already solved: " <> prettyNameString v + (Just TyVarLink {}, _) -> + error $ "Type variable already linked: " <> prettyNameString v + (Nothing, _) -> + error $ "linkTyVar: Nothing v: " <> prettyNameString v + + setInfo v (TyVarSol lvl t) + +setInfo :: TyVar -> TyVarSol -> SolveM () +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v info $ solverTyVars s} + +-- Precondition: 'v' is currently flexible and 't' has no solution. linkTyVar :: Reason -> VName -> VName -> SolveM () linkTyVar reason v t = do occursCheck reason v $ Scalar $ TypeVar NoUniqueness (qualName t) [] - tyvars <- gets solverTyVars - modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) $ solverTyVars s} - tyvars' <- - case (M.lookup v tyvars, M.lookup t tyvars) of - (Just (TyVarUnsol _ info), Just (TyVarUnsol lvl TyVarFree {})) -> - pure $ M.insert t (TyVarUnsol lvl info) tyvars - (Just (TyVarUnsol _ info@TyVarPrim {}), Just (TyVarUnsol lvl TyVarEql {})) -> - pure $ M.insert t (TyVarUnsol lvl info) tyvars - -- TODO: handle more cases. - _ -> pure tyvars - modify $ \s -> s {solverTyVars = M.insert v (TyVarLink t) tyvars'} + v_info <- gets $ M.lookup v . solverTyVars + (lvl, t') <- lookupTyVar t + case (v_info, t') of + -- When either is completely unconstrained. + (Just (TyVarUnsol _ TyVarFree {}), _) -> + pure () + ( Just (TyVarUnsol _ info), + Left (TyVarFree {}) + ) -> + setInfo t (TyVarUnsol lvl info) + -- + -- TyVarPrim cases + ( Just (TyVarUnsol _ info@TyVarPrim {}), + Left TyVarEql {} + ) -> + setInfo t (TyVarUnsol lvl info) + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + Left (TyVarPrim t_loc t_pts) + ) -> + let pts = L.intersect v_pts t_pts + in if null pts + then + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be one of" + indent 2 (pretty t_pts) + else setInfo t (TyVarUnsol lvl (TyVarPrim t_loc pts)) + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + Left TyVarRecord {} + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be record." + ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + Left TyVarSum {} + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be sum." + -- + -- TyVarSum cases + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Left (TyVarSum loc cs2) + ) -> do + unifySharedConstructors reason cs1 cs2 + let cs3 = cs1 <> cs2 + setInfo t (TyVarUnsol lvl (TyVarSum loc cs3)) + ( Just (TyVarUnsol _ TyVarSum {}), + Left (TyVarPrim _ pts) + ) -> + throwError . TypeError (locOf reason) mempty $ + "A sum type cannot be one of" + indent 2 (pretty pts) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Left (TyVarRecord _ fs) + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Scalar (Record fs))) + ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + Left (TyVarEql _) + ) -> + mapM_ (mapM_ (mustSupportEql reason)) cs1 + -- + -- TyVarRecord cases + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Left (TyVarRecord loc fs2) + ) -> do + unifySharedFields reason fs1 fs2 + let fs3 = fs1 <> fs2 + setInfo t (TyVarUnsol lvl (TyVarRecord loc fs3)) + ( Just (TyVarUnsol _ TyVarRecord {}), + Left (TyVarPrim _ pts) + ) -> + throwError . TypeError (locOf reason) mempty $ + "A record type cannot be one of" + indent 2 (pretty pts) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Left (TyVarSum _ cs) + ) -> + throwError . TypeError (locOf reason) mempty $ + "Cannot unify record type" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty (Scalar (Sum cs))) + ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + Left (TyVarEql _) + ) -> + mapM_ (mustSupportEql reason) fs1 + -- + -- TyVarEql cases + (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarPrim {}) -> + pure () + (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarEql {}) -> + pure () + (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarRecord _ fs)) -> + mustSupportEql reason $ Scalar $ Record fs + (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarSum _ cs)) -> + mustSupportEql reason $ Scalar $ Sum cs + -- + -- Internal error cases + (Just TyVarSol {}, _) -> + error $ "Type variable already solved: " <> prettyNameString v + (Just TyVarLink {}, _) -> + error $ "Type variable already linked: " <> prettyNameString v + (Nothing, _) -> + error $ "linkTyVar: Nothing v: " <> prettyNameString v + (_, Right t'') -> + error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' + + -- Finally insert the actual link. + setInfo v (TyVarLink t) -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -293,60 +498,28 @@ solveCt ct = CtAM {} -> pure () -- Good vibes only. solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () -solveTyVar (_, (_, TyVarFree {})) = pure () -solveTyVar (tv, (_, TyVarPrim loc pts)) = do - tv_t <- lookupTyVar tv - case tv_t of - Nothing -> pure () - Just t' - | t' `elem` map (Scalar . Prim) pts -> pure () - | otherwise -> - throwError . TypeError loc mempty $ - "Type must be one of" - indent 2 (pretty pts) - "but inferred to be" - indent 2 (pretty t') solveTyVar (tv, (_, TyVarRecord loc fs1)) = do - tv_t <- lookupTyVar tv + (_, tv_t) <- lookupTyVar tv case tv_t of - Nothing -> + Left _ -> throwError . TypeError loc mempty $ - "Type is ambiguous." + "Type" + <+> prettyName tv + <+> "is ambiguous." "Must be a record with fields" indent 2 (pretty (Scalar (Record fs1))) - Just (Scalar (Record fs2)) - | all (`M.member` fs2) (M.keys fs1) -> - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_k, (t1, t2)) -> - solveCt $ CtEq (Reason loc) t1 t2 - Just tv_t' -> - throwError . TypeError loc mempty $ - "Type must be record with fields" - indent 2 (pretty (Scalar (Record fs1))) - "but inferred to be" - indent 2 (pretty tv_t') + Right _ -> + pure () solveTyVar (tv, (_, TyVarSum loc cs1)) = do - tv_t <- lookupTyVar tv + (_, tv_t) <- lookupTyVar tv case tv_t of - Nothing -> + Left _ -> throwError . TypeError loc mempty $ "Type is ambiguous." "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) - Just (Scalar (Sum cs2)) - | all (`M.member` cs2) (M.keys cs1), - cs3 <- M.toList $ M.intersectionWith (,) cs1 cs2, - all (sameLength . snd) cs3 -> - forM_ cs3 $ \(_k, (t1s, t2s)) -> - mapM_ solveCt $ zipWith (CtEq (Reason loc)) t1s t2s - Just tv_t' -> - throwError . TypeError loc mempty $ - "Type must be sum type with constructors" - indent 2 (pretty (Scalar (Sum cs1))) - "but inferred to be" - indent 2 (pretty tv_t') - where - sameLength (x, y) = length x == length y -solveTyVar (_, (_, TyVarEql _)) = + Right _ -> pure () +solveTyVar (_, _) = pure () solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) From 2bdca0efcde1e906e86857d10619082bc96b99e3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 Mar 2024 14:15:21 +0100 Subject: [PATCH 209/296] Substitute dependent sizes when expanding automaps. --- src/Futhark/Internalise/Defunctionalise.hs | 6 +++--- src/Futhark/Internalise/FullNormalise.hs | 10 ++++++++-- src/Language/Futhark/Interpreter.hs | 2 +- src/Language/Futhark/Prop.hs | 17 ++++++++++------- src/Language/Futhark/TypeChecker.hs | 2 +- src/Language/Futhark/TypeChecker/Consumption.hs | 2 +- src/Language/Futhark/TypeChecker/Types.hs | 2 +- 7 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 82cc845d69..783c374409 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -851,7 +851,7 @@ unRetType (RetType ext t) = do defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal) defuncApplyFunction e@(Var qn (Info t) loc) num_args = do - let (argtypes, rettype) = unfoldFunType t + let (argtypes, rettype) = first (map snd) $ unfoldFunType t sv <- lookupVar (toStruct t) (qualLeaf qn) case sv of @@ -1001,8 +1001,8 @@ defuncApply f args appres loc = do (argtypes, _) = unfoldFunType $ typeOf f fmap (first $ updateReturn appres) $ foldM (defuncApplyArg fname) (f', f_sv) $ - NE.zip args $ - NE.tails argtypes + NE.zip args . NE.tails . map snd $ + argtypes where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 789e0d3c85..2fe87f6244 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -499,7 +499,7 @@ transformProg = mapM transformValBind -- | Expands 'AutoMap' annotations into explicit @map@s and @replicates@. expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp -expandAMAnnotations e = do +expandAMAnnotations e = case e of (AppExp (Apply f args loc) (Info res)) -> do let ((exts, ams), arg_es) = first unzip $ unzip $ map (first unInfo) $ NE.toList args @@ -511,7 +511,9 @@ expandAMAnnotations e = do case unfoldFunTypeWithRet $ typeOf f' of Nothing -> error "Function type expected." Just (ptypes, f_ret) -> - foldFunType (drop (length args') ptypes) f_ret + let parsubsts = mapMaybe parSub $ zip ptypes args' + in applySubst (`lookup` parsubsts) $ + foldFunType (drop (length args') $ map snd ptypes) f_ret pure $ mkApply f' (zip3 exts (repeat mempty) args') $ res {appResType = rettype} @@ -531,6 +533,10 @@ expandAMAnnotations e = do (Info res {appResType = stripArray (shapeRank $ autoFrame yam) (appResType res)}) _ -> astMap identityMapper {mapOnExp = expandAMAnnotations} e where + parSub ((Named v, Scalar (Prim (Signed Int64))), arg) = + Just (v, ExpSubst arg) + parSub _ = Nothing + setNewType e t = astMap identityMapper {mapOnStructType = const $ pure t} e funDiets :: TypeBase dim as -> [Diet] diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 506a8b715d..2ee02502c9 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -2031,7 +2031,7 @@ checkEntryArgs entry args entry_t "Got input of types" indent 2 (stack (map pretty args_ts)) where - (param_ts, _) = unfoldFunType entry_t + param_ts = map snd $ fst $ unfoldFunType entry_t args_ts = map (valueStructType . valueType) args expected | null param_ts = diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index aa284e0122..6cff2d91a4 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -526,19 +526,22 @@ foldFunType ps ret = -- | Extract the parameter types and return type from a type. -- If the type is not an arrow type, the list of parameter types is empty. -unfoldFunType :: TypeBase dim as -> ([TypeBase dim Diet], TypeBase dim NoUniqueness) -unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) = +unfoldFunType :: TypeBase dim as -> ([(PName, TypeBase dim Diet)], TypeBase dim NoUniqueness) +unfoldFunType (Scalar (Arrow _ p d t1 (RetType _ t2))) = let (ps, r) = unfoldFunType t2 - in (second (const d) t1 : ps, r) + in ((p, second (const d) t1) : ps, r) unfoldFunType t = ([], toStruct t) -- | Extract the parameter types and 'RetTypeBase' from a function type. -- If the type is not an arrow type, returns 'Nothing'. -unfoldFunTypeWithRet :: TypeBase dim as -> Maybe ([TypeBase dim Diet], RetTypeBase dim Uniqueness) -unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 (RetType _ t2@(Scalar Arrow {})))) = do +unfoldFunTypeWithRet :: + TypeBase dim as -> + Maybe ([(PName, TypeBase dim Diet)], RetTypeBase dim Uniqueness) +unfoldFunTypeWithRet (Scalar (Arrow _ p d t1 (RetType _ t2@(Scalar Arrow {})))) = do (ps, r) <- unfoldFunTypeWithRet t2 - pure (second (const d) t1 : ps, r) -unfoldFunTypeWithRet (Scalar (Arrow _ _ d t1 r@RetType {})) = Just ([second (const d) t1], r) + pure ((p, second (const d) t1) : ps, r) +unfoldFunTypeWithRet (Scalar (Arrow _ p d t1 r@RetType {})) = + Just ([(p, second (const d) t1)], r) unfoldFunTypeWithRet _ = Nothing -- | The type scheme of a value binding, comprising the type diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index a69722ee2b..f87e280330 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -690,7 +690,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype where (RetType _ rettype_t) = rettype (rettype_params, rettype') = unfoldFunType rettype_t - param_ts = map patternType params ++ rettype_params + param_ts = map patternType params ++ map snd rettype_params checkValBind :: ValBindBase NoInfo Name -> TypeM (Env, ValBind) checkValBind vb = do diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index 8c92e54d20..2f194f9c85 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -810,7 +810,7 @@ checkExp (AppExp (LetFun fname (typarams, params, te, Info (RetType ext ret), fu -- checkExp (AppExp (BinOp (op, oploc) opt (x, xp) (y, yp) loc) appres) = do op_als <- observeVar (locOf oploc) (qualLeaf op) (unInfo opt) - let at1 : at2 : _ = fst $ unfoldFunType op_als + let (_, at1) : (_, at2) : _ = fst $ unfoldFunType op_als (x', x_als) <- checkArg [] at1 mempty x (y', y_als) <- checkArg [(x', x_als)] at2 mempty y res_als <- checkFuncall loc (Just op) op_als [(mempty, x_als), (mempty, y_als)] diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 02126dbfd8..3070e273ca 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -60,7 +60,7 @@ mustBeExplicitInBinding :: StructType -> S.Set VName mustBeExplicitInBinding bind_t = let (ts, ret) = unfoldFunType bind_t alsoRet = M.unionWith (&&) $ M.fromList $ map (,True) (S.toList (fvVars (freeInType ret))) - in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty $ map toStruct ts + in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty $ map (toStruct . snd) ts where onType uses t = uses <> mustBeExplicitAux t -- Left-biased union. From 06d7232fb10884f1936e85d3904d04573ba3fd80 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 Mar 2024 14:31:43 +0100 Subject: [PATCH 210/296] Workaround for wrong return type handling. --- src/Futhark/Internalise/FullNormalise.hs | 36 +++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 2fe87f6244..3afecad4bb 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -501,22 +501,26 @@ transformProg = mapM transformValBind expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp expandAMAnnotations e = case e of - (AppExp (Apply f args loc) (Info res)) -> do - let ((exts, ams), arg_es) = first unzip $ unzip $ map (first unInfo) $ NE.toList args - f' <- expandAMAnnotations f - arg_es' <- mapM expandAMAnnotations arg_es - let diets = funDiets $ typeOf f - withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do - let rettype = - case unfoldFunTypeWithRet $ typeOf f' of - Nothing -> error "Function type expected." - Just (ptypes, f_ret) -> - let parsubsts = mapMaybe parSub $ zip ptypes args' - in applySubst (`lookup` parsubsts) $ - foldFunType (drop (length args') $ map snd ptypes) f_ret - pure $ - mkApply f' (zip3 exts (repeat mempty) args') $ - res {appResType = rettype} + (AppExp (Apply f args loc) (Info res)) + | ((exts, ams), arg_es) <- + first unzip $ unzip $ map (first unInfo) $ NE.toList args, + any (/= mempty) ams -> do + f' <- expandAMAnnotations f + arg_es' <- mapM expandAMAnnotations arg_es + let diets = funDiets $ typeOf f + withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do + let rettype = + case unfoldFunTypeWithRet $ typeOf f' of + Nothing -> error "Function type expected." + Just (ptypes, f_ret) -> + let parsubsts = mapMaybe parSub $ zip ptypes args' + in applySubst (`lookup` parsubsts) $ + foldFunType (drop (length args') $ map snd ptypes) f_ret + when (appResExt res /= []) $ + error "expandAMAnnotations: cannot handle existential yet." + pure $ + mkApply f' (zip3 exts (repeat mempty) args') $ + res {appResType = rettype} (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x y' <- expandAMAnnotations y From a1628f3c120987e781da824aa2ddafd20c560701 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 Mar 2024 14:31:50 +0100 Subject: [PATCH 211/296] Handle special case. --- src/Language/Futhark/Interpreter/Values.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index b3fb36ac8c..de3bd2468d 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE LambdaCase #-} + -- | The value representation used in the interpreter. -- -- Kept simple and free of unnecessary operational details (in @@ -210,7 +212,9 @@ toArray' rowshape vs = ValueArray shape (listArray (0, length vs - 1) vs) -- | Produce multidimensional array from a flat list of values. toArrayR :: [Int64] -> ValueShape -> [Value m] -> Value m -toArrayR [] _ = error "toArrayR: empty shape" +toArrayR [] _ = \case + [v] -> v + _ -> error "toArrayR: empty shape" toArrayR [_] elemshape = toArray' elemshape toArrayR (n : ns) elemshape = toArray (foldr ShapeDim elemshape (n : ns)) From 6ce8ed3258e42618784f97d8f11272cb5fa07ef0 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 2 Apr 2024 12:34:05 +0200 Subject: [PATCH 212/296] Must expand here. --- src/Language/Futhark/Interpreter.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 2ee02502c9..06acadc6e6 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -900,7 +900,7 @@ evalAppExp env (Apply f args loc) = do f' <- eval env f foldM apply' f' args' where - ft = typeOf f + ft = expandType env $ typeOf f apply' f' (v', am') = applyAM loc env (f', ft) am' v' evalArg' (Info (ext, am), x) = evalArg env x ext am evalAppExp env (Index e is loc) = do From 05311de15be790bccd10fea029dc35f82d501f9a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 15:11:40 +0200 Subject: [PATCH 213/296] Strangle warnings. --- src/Futhark/Internalise/Exps.hs | 4 +- src/Futhark/Internalise/FullNormalise.hs | 54 +++---- src/Futhark/Solve/BranchAndBound.hs | 3 +- src/Futhark/Solve/GLPK.hs | 5 +- src/Futhark/Solve/LP.hs | 129 +++++++---------- src/Futhark/Solve/Matrix.hs | 18 +-- src/Futhark/Solve/Simplex.hs | 135 +++++++++--------- .../Futhark/TypeChecker/Constraints.hs | 2 +- src/Language/Futhark/TypeChecker/Terms.hs | 4 +- .../Futhark/TypeChecker/Terms/Monad.hs | 3 - src/Language/Futhark/TypeChecker/Terms2.hs | 20 +-- 11 files changed, 153 insertions(+), 224 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 1002b1bfa8..1669ecd5e7 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1583,7 +1583,7 @@ isIntrinsicFunction :: [E.Exp] -> SrcLoc -> Maybe (String -> InternaliseM [SubExp]) -isIntrinsicFunction qname args loc = do +isIntrinsicFunction qname all_args loc = do guard $ baseTag (qualLeaf qname) <= maxIntrinsicTag let handlers = [ handleSign, @@ -1593,7 +1593,7 @@ isIntrinsicFunction qname args loc = do handleAD, handleRest ] - msum [h args $ baseString $ qualLeaf qname | h <- handlers] + msum [h all_args $ baseString $ qualLeaf qname | h <- handlers] where handleSign [x] "sign_i8" = Just $ toSigned I.Int8 x handleSign [x] "sign_i16" = Just $ toSigned I.Int16 x diff --git a/src/Futhark/Internalise/FullNormalise.hs b/src/Futhark/Internalise/FullNormalise.hs index 97f8c58513..06ddb255f0 100644 --- a/src/Futhark/Internalise/FullNormalise.hs +++ b/src/Futhark/Internalise/FullNormalise.hs @@ -23,17 +23,14 @@ module Futhark.Internalise.FullNormalise (transformProg) where import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor -import Data.Functor.Identity import Data.List (zip4) import Data.List.NonEmpty qualified as NE import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T -import Debug.Trace import Futhark.MonadFreshNames import Futhark.Util.Pretty import Language.Futhark -import Language.Futhark.Pretty import Language.Futhark.Primitive (intValue) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types @@ -502,14 +499,14 @@ transformProg = mapM transformValBind expandAMAnnotations :: (MonadFreshNames m) => Exp -> m Exp expandAMAnnotations e = case e of - (AppExp (Apply f args loc) (Info res)) + (AppExp (Apply f args _) (Info res)) | ((exts, ams), arg_es) <- first unzip $ unzip $ map (first unInfo) $ NE.toList args, any (/= mempty) ams -> do f' <- expandAMAnnotations f arg_es' <- mapM expandAMAnnotations arg_es let diets = funDiets $ typeOf f - withMapNest loc (zip4 exts ams arg_es' diets) $ \args' -> do + withMapNest (zip4 exts ams arg_es' diets) $ \args' -> do let rettype = case unfoldFunTypeWithRet $ typeOf f' of Nothing -> error "Function type expected." @@ -525,7 +522,7 @@ expandAMAnnotations e = (AppExp (BinOp op (Info t) (x, Info (xext, xam)) (y, Info (yext, yam)) loc) (Info res)) -> do x' <- expandAMAnnotations x y' <- expandAMAnnotations y - withMapNest loc [(xext, xam, x', Observe), (yext, yam, y', Observe)] $ \[x'', y''] -> + withMapNest [(xext, xam, x', Observe), (yext, yam, y', Observe)] $ \[x'', y''] -> pure $ AppExp ( BinOp @@ -542,31 +539,13 @@ expandAMAnnotations e = Just (v, ExpSubst arg) parSub _ = Nothing - setNewType e t = astMap identityMapper {mapOnStructType = const $ pure t} e - funDiets :: TypeBase dim as -> [Diet] funDiets (Scalar (Arrow _ _ d _ (RetType _ t2))) = d : funDiets t2 funDiets _ = [] - dropDims :: Int -> TypeBase dim as -> TypeBase dim as - dropDims n (Scalar (Arrow u p diet t1 (RetType ds t2))) = - Scalar (Arrow u p diet (stripArray n t1) (RetType ds (dropDims n t2))) - dropDims n t = stripArray n t - - innerFType :: TypeBase dim as -> [AutoMap] -> TypeBase dim as - innerFType (Scalar (Arrow u p diet t1 (RetType ds t2))) ams = - Scalar $ Arrow u p diet t1 $ RetType ds $ innerFType' t2 ams - where - innerFType' t [] = t - innerFType' (Scalar (Arrow u p diet t1 (RetType ds t2))) (am : ams) = - Scalar $ Arrow u p diet (dropDims (shapeRank (autoMap am)) t1) $ RetType ds $ innerFType' t2 ams - innerFType' t [am] = dropDims (shapeRank (autoMap am)) t - innerFType' _ _ = error "" - innerFType _ _ = error "" - type Level = Int -data AutoMapArg = AutoMapArg +newtype AutoMapArg = AutoMapArg { amArg :: Exp } deriving (Show) @@ -582,13 +561,12 @@ data AutoMapParam = AutoMapParam withMapNest :: forall m. (MonadFreshNames m) => - SrcLoc -> [(Maybe VName, AutoMap, Exp, Diet)] -> ([Exp] -> m Exp) -> m Exp -withMapNest loc args f = do +withMapNest nest_args f = do (param_map, arg_map) <- - bimap combineMaps combineMaps . unzip <$> mapM buildArgMap args + bimap combineMaps combineMaps . unzip <$> mapM buildArgMap nest_args buildMapNest param_map arg_map $ maximum $ M.keys arg_map where combineMaps :: (Ord k) => [M.Map k v] -> M.Map k [v] @@ -609,17 +587,17 @@ withMapNest loc args f = do args = map amArg $ arg_map M.! l body <- buildMapNest param_map arg_map (l - 1) pure $ - mkMap map_dim params body args $ + mkMap params body args $ RetType [] $ arrayOfWithAliases Nonunique (Shape [map_dim]) (typeOf body) buildArgMap :: (Maybe VName, AutoMap, Exp, Diet) -> m (M.Map Level AutoMapParam, M.Map Level AutoMapArg) - buildArgMap (ext, am, arg, diet) = - foldM (mkArgsAndParams arg) mempty $ reverse [0 .. trueLevel am] + buildArgMap (_ext, am, arg, arg_diet) = + foldM mkArgsAndParams mempty $ reverse [0 .. trueLevel am] where - mkArgsAndParams arg (p_map, a_map) l + mkArgsAndParams (p_map, a_map) l | l == 0 = do let arg' = maybe arg (paramToExp . amParam) (p_map M.!? 1) rarg <- mkReplicateShape (autoRep am `shapePrefix` autoFrame am) arg' @@ -628,7 +606,7 @@ withMapNest loc args f = do p <- mkAMParam (typeOf arg) l let d = outerDim am l pure - ( M.insert l (AutoMapParam p d diet) p_map, + ( M.insert l (AutoMapParam p d arg_diet) p_map, M.insert l (AutoMapArg arg) a_map ) | l < trueLevel am && l > 0 = do @@ -639,7 +617,7 @@ withMapNest loc args f = do amParam $ p_map M.! (l + 1) pure - ( M.insert l (AutoMapParam p d diet) p_map, + ( M.insert l (AutoMapParam p d arg_diet) p_map, M.insert l (AutoMapArg arg') a_map ) | otherwise = error "Impossible." @@ -672,13 +650,13 @@ mkReplicate :: (MonadFreshNames m) => Exp -> Exp -> m Exp mkReplicate dim e = do x <- mkParam "x" (Scalar $ Prim $ Unsigned Int64) pure $ - mkMap dim [(Observe, x)] e [xs] $ + mkMap [(Observe, x)] e [xs] $ RetType mempty (arrayOfWithAliases Unique (Shape [dim]) (typeOf e)) where xs = AppExp ( Range - (Literal (UnsignedValue $ intValue Int64 0) mempty) + (Literal (UnsignedValue $ intValue Int64 (0 :: Int)) mempty) Nothing (UpToExclusive dim) mempty @@ -686,8 +664,8 @@ mkReplicate dim e = do ( Info $ AppRes (arrayOf (Shape [dim]) (Scalar $ Prim $ Unsigned Int64)) [] ) -mkMap :: Exp -> [(Diet, Pat ParamType)] -> Exp -> [Exp] -> ResRetType -> Exp -mkMap dim params body arrs rettype = +mkMap :: [(Diet, Pat ParamType)] -> Exp -> [Exp] -> ResRetType -> Exp +mkMap params body arrs rettype = mkApply mapN args (AppRes (toStruct $ retType rettype) []) where args = map (Nothing,mempty,) $ lambda : arrs diff --git a/src/Futhark/Solve/BranchAndBound.hs b/src/Futhark/Solve/BranchAndBound.hs index 846ae4a59a..258757113b 100644 --- a/src/Futhark/Solve/BranchAndBound.hs +++ b/src/Futhark/Solve/BranchAndBound.hs @@ -5,7 +5,6 @@ import Data.Maybe import Data.Set qualified as S import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.LP (LP (..)) import Futhark.Solve.Matrix import Futhark.Solve.Simplex @@ -54,7 +53,7 @@ branchAndBound prob@(LP _ a d) = (zopt,) <$> mopt -- TODO: use isInt x = x == round x -- requires a better 'rowEchelon' implementation for matrices - isInt x = (abs (fromIntegral (round x) - x)) <= 10 ^^ (-10) + isInt x = abs (fromIntegral (round x :: Int) - x) <= 10 ^^ ((-10) :: Int) mkProblem = M.foldrWithKey ( \idx bound acc -> addBound acc idx bound diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs index b4b0b4602b..5c8f40fcd8 100644 --- a/src/Futhark/Solve/GLPK.hs +++ b/src/Futhark/Solve/GLPK.hs @@ -1,6 +1,7 @@ module Futhark.Solve.GLPK (glpk) where import Control.Monad +import Data.Bifunctor import Data.LinearProgram import Data.Map qualified as M import Data.Maybe @@ -8,7 +9,7 @@ import Data.Set qualified as S import Futhark.Solve.LP qualified as F import System.IO.Silently -linearProgToGLPK :: (Ord v, Eq a, Num a) => F.LinearProg v a -> (LP v a) +linearProgToGLPK :: (Ord v, Num a) => F.LinearProg v a -> LP v a linearProgToGLPK prog = LP { direction = cOptType $ F.optType prog, @@ -54,6 +55,6 @@ glpk' lp pure $ pure (0, M.fromList $ map (,0) $ S.toList $ F.vars lp) | otherwise = do (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp - pure $ (\(opt, vs) -> (truncate opt, fmap truncate vs)) <$> mres + pure $ bimap truncate (fmap truncate) <$> mres where opts = mipDefaults {msgLev = MsgAll} diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs index 47804b738b..5011ece9fb 100644 --- a/src/Futhark/Solve/LP.hs +++ b/src/Futhark/Solve/LP.hs @@ -32,23 +32,18 @@ module Futhark.Solve.LP ) where -import Control.Monad.LPMonad -import Data.Char (isAscii) -import Data.List qualified as L import Data.Map (Map) -import Data.Map qualified as Map +import Data.Map qualified as M import Data.Maybe import Data.Set (Set) import Data.Set qualified as S import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.Matrix (Matrix (..)) -import Futhark.Solve.Matrix qualified as M +import Futhark.Solve.Matrix qualified as Matrix import Futhark.Util.Pretty import Language.Futhark.Pretty import Prelude hiding (max, min, or) -import Prelude qualified -- | A linear program. 'LP c a d' represents the program -- @@ -79,24 +74,25 @@ data LPE a = LPE } deriving (Eq, Show) -rowEchelonLPE :: (Show a, Unbox a, Fractional a, Ord a) => LPE a -> LPE a +rowEchelonLPE :: (Unbox a, Fractional a, Ord a) => LPE a -> LPE a rowEchelonLPE (LPE c a d) = - LPE c (M.sliceCols (V.generate (ncols a) id) ad) (M.getCol (ncols a) ad) + LPE c (Matrix.sliceCols (V.generate (ncols a) id) ad) (Matrix.getCol (ncols a) ad) where ad = - M.filterRows (V.any (Prelude./= 0)) $ - (M.rowEchelon $ a M.<|> M.fromColVector d) + Matrix.filterRows + (V.any (Prelude./= 0)) + (Matrix.rowEchelon $ a Matrix.<|> Matrix.fromColVector d) -- | Converts an 'LP' into an equivalent 'LPE' by introducing slack -- variables. -convert :: (Show a, Num a, Unbox a) => LP a -> LPE a +convert :: (Num a, Unbox a) => LP a -> LPE a convert (LP c a d) = LPE c' a' d where - a' = a M.<|> M.diagonal (V.replicate (M.nrows a) 1) - c' = c V.++ V.replicate (M.nrows a) 0 + a' = a Matrix.<|> Matrix.diagonal (V.replicate (Matrix.nrows a) 1) + c' = c V.++ V.replicate (Matrix.nrows a) 0 -- | Linear sum of variables. -newtype LSum v a = LSum {lsum :: (Map (Maybe v) a)} +newtype LSum v a = LSum {lsum :: Map (Maybe v) a} deriving (Show, Eq) instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where @@ -108,10 +104,10 @@ instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where Nothing -> pretty a Just k' -> (if a == 1 then mempty else pretty a <> "*") <> prettyName k' ) - $ Map.toList m + $ M.toList m isConstant :: (Ord v) => LSum v a -> Bool -isConstant (LSum m) = Map.keysSet m `S.isSubsetOf` S.singleton Nothing +isConstant (LSum m) = M.keysSet m `S.isSubsetOf` S.singleton Nothing instance Functor (LSum v) where fmap f (LSum m) = LSum $ fmap f m @@ -120,7 +116,7 @@ class Vars a v where vars :: a -> Set v instance (Ord v) => Vars (LSum v a) v where - vars = S.fromList . catMaybes . Map.keys . lsum + vars = S.fromList . catMaybes . M.keys . lsum -- | Type of constraint data CType = Equal | LessEq @@ -159,7 +155,7 @@ data LinearProg v a = LinearProg instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where pretty (LinearProg opt obj cs) = - vcat $ + vcat [ pretty opt, indent 2 $ pretty obj, "subject to", @@ -172,10 +168,10 @@ instance (Ord v) => Vars (LinearProg v a) v where <> foldMap vars (constraints lp) bigM :: (Num a) => a -bigM = 2 ^ 10 +bigM = 2 ^ (10 :: Int) -- max{x, y} = z -max :: (Eq a, Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] +max :: (Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] max b x y z = [ z ~>=~ x, z ~>=~ y, @@ -184,7 +180,7 @@ max b x y z = ] -- min{x, y} = z -min :: (Eq a, Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +min :: (Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] min b x y z = [ var z ~<=~ var x, var z ~<=~ var y, @@ -192,7 +188,7 @@ min b x y z = var z ~>=~ var y ~-~ bigM ~*~ var b ] -oneIsZero :: (Eq a, Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] +oneIsZero :: (Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] oneIsZero (b1, x1) (b2, x2) = mkC b1 x1 <> mkC b2 x2 @@ -202,7 +198,7 @@ oneIsZero (b1, x1) (b2, x2) = [ var x ~<=~ bigM ~*~ var b ] -or :: (Eq a, Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] +or :: (Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] or b1 b2 c1 c2 = mkC b1 c1 <> mkC b2 c2 @@ -216,94 +212,89 @@ or b1 b2 c1 c2 = [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b) ] -bin :: (Num a, Ord v) => v -> Constraint v a +bin :: (Num a) => v -> Constraint v a bin v = Constraint LessEq (var v) (constant 1) -(~==~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +(~==~) :: LSum v a -> LSum v a -> Constraint v a l ~==~ r = Constraint Equal l r infix 4 ~==~ -(~<=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +(~<=~) :: LSum v a -> LSum v a -> Constraint v a l ~<=~ r = Constraint LessEq l r infix 4 ~<=~ -(~>=~) :: (Num a, Ord v) => LSum v a -> LSum v a -> Constraint v a +(~>=~) :: (Num a) => LSum v a -> LSum v a -> Constraint v a l ~>=~ r = Constraint LessEq (neg l) (neg r) infix 4 ~>=~ normalize :: (Eq a, Num a) => LSum v a -> LSum v a -normalize = LSum . Map.filter (/= 0) . lsum +normalize = LSum . M.filter (/= 0) . lsum var :: (Num a) => v -> LSum v a -var v = LSum $ Map.singleton (Just v) (fromInteger 1) +var v = LSum $ M.singleton (Just v) 1 constant :: a -> LSum v a -constant = LSum . Map.singleton Nothing +constant = LSum . M.singleton Nothing cval :: (Num a, Ord v) => LSum v a -> a cval = (! Nothing) -(~+~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a --- (LSum x) ~+~ (LSum y) = normalize $ LSum $ Map.unionWith (+) x y -(LSum x) ~+~ (LSum y) = LSum $ Map.unionWith (+) x y +(~+~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +(LSum x) ~+~ (LSum y) = LSum $ M.unionWith (+) x y infixl 6 ~+~ -(~-~) :: (Eq a, Num a, Ord v) => LSum v a -> LSum v a -> LSum v a -x ~-~ y = x ~+~ (neg y) +(~-~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +x ~-~ y = x ~+~ neg y infixl 6 ~-~ -(~*~) :: (Eq a, Num a, Ord v) => a -> LSum v a -> LSum v a --- a ~*~ s = normalize $ fmap (a *) s +(~*~) :: (Num a) => a -> LSum v a -> LSum v a a ~*~ s = fmap (a *) s infixl 7 ~*~ (!) :: (Num a, Ord v) => LSum v a -> Maybe v -> a -(LSum m) ! v = - case m Map.!? v of - Nothing -> 0 - Just a -> a +(LSum m) ! v = fromMaybe 0 (m M.!? v) -neg :: (Num a, Ord v) => LSum v a -> LSum v a +neg :: (Num a) => LSum v a -> LSum v a neg (LSum x) = LSum $ fmap negate x -- | Converts a linear program given with a list of constraints -- into the standard form. linearProgToLP :: forall v a. - (Unbox a, Num a, Ord v, Eq a) => + (Unbox a, Num a, Ord v) => LinearProg v a -> (LP a, Map Int v) linearProgToLP (LinearProg otype obj cs) = - (LP c a d, idxMap) + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LP c a d, idxMap) where cs' = foldMap (convertEqCType . splitConstraint) cs idxMap = - Map.fromList $ + M.fromList $ zip [0 ..] $ catMaybes $ - Map.keys $ + M.keys $ mconcat $ map (lsum . fst) cs' - mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) - c = mkRow $ convertObj otype obj - a = M.fromVectors $ map (mkRow . fst) cs' - d = V.fromList $ map snd cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) + + convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] + convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] + convertEqCType (LessEq, s, a) = [(s, a)] splitConstraint :: Constraint v a -> (CType, LSum v a, a) splitConstraint (Constraint ctype l r) = let c = negate $ cval (l ~-~ r) in (ctype, l ~-~ r ~-~ constant c, c) - convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] - convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] - convertEqCType (LessEq, s, a) = [(s, a)] - convertObj :: OptType -> LSum v a -> LSum v a convertObj Maximize s = s convertObj Minimize s = neg s @@ -312,24 +303,24 @@ linearProgToLP (LinearProg otype obj cs) = -- into the equational form. Assumes no <= constraints. linearProgToLPE :: forall v a. - (Unbox a, Num a, Ord v, Eq a) => + (Unbox a, Num a, Ord v) => LinearProg v a -> (LPE a, Map Int v) linearProgToLPE (LinearProg otype obj cs) = - (LPE c a d, idxMap) + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LPE c a d, idxMap) where cs' = map (checkOnlyEqType . splitConstraint) cs idxMap = - Map.fromList $ + M.fromList $ zip [0 ..] $ catMaybes $ - Map.keys $ + M.keys $ mconcat $ map (lsum . fst) cs' - mkRow s = V.generate (Map.size idxMap) $ \i -> s ! Just (idxMap Map.! i) - c = mkRow $ convertObj otype obj - a = M.fromVectors $ map (mkRow . fst) cs' - d = V.fromList $ map snd cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) splitConstraint :: Constraint v a -> (CType, LSum v a, a) splitConstraint (Constraint ctype l r) = @@ -343,15 +334,3 @@ linearProgToLPE (LinearProg otype obj cs) = convertObj :: OptType -> LSum v a -> LSum v a convertObj Maximize s = s convertObj Minimize s = neg s - -test1 :: LPE Double -test1 = - LPE - { pc = V.fromList [5.5, 2.1], - pA = - M.fromLists - [ [-1, 1], - [8, 2] - ], - pd = V.fromList [2, 17] - } diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs index ae3bdf6b7c..39ec16a39e 100644 --- a/src/Futhark/Solve/Matrix.hs +++ b/src/Futhark/Solve/Matrix.hs @@ -36,7 +36,6 @@ module Futhark.Solve.Matrix where import Data.List qualified as L -import Data.Map (Map) import Data.Map qualified as M import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V @@ -84,7 +83,7 @@ fromVectors :: (Unbox a) => [Vector a] -> Matrix a fromVectors [] = empty fromVectors vs = Matrix - { elems = V.concat $ vs, + { elems = V.concat vs, nrows = length vs, ncols = V.length $ head vs } @@ -263,18 +262,7 @@ update :: (Unbox a) => Matrix a -> Vector ((Int, Int), a) -> Matrix a update m upds = generate ( \i j -> - case (M.fromList $ V.toList upds) M.!? (i, j) of - Nothing -> m ! (i, j) - Just x -> x - ) - (nrows m) - (ncols m) - -update_ :: (Unbox a) => Matrix a -> Map (Int, Int) a -> Matrix a -update_ m upds = - generate - ( \i j -> - case upds M.!? (i, j) of + case M.fromList (V.toList upds) M.!? (i, j) of Nothing -> m ! (i, j) Just x -> x ) @@ -282,7 +270,7 @@ update_ m upds = (ncols m) -- This version doesn't maintain integrality of the entries. -rowEchelon :: (Num a, Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a +rowEchelon :: (Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a rowEchelon = rowEchelon' 0 0 where rowEchelon' h k m@(Matrix _ nr nc) diff --git a/src/Futhark/Solve/Simplex.hs b/src/Futhark/Solve/Simplex.hs index e01c7ce566..362b300038 100644 --- a/src/Futhark/Solve/Simplex.hs +++ b/src/Futhark/Solve/Simplex.hs @@ -12,7 +12,6 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Vector.Unboxed (Unbox, Vector) import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.LP (LP (..), LPE (..), LinearProg (..), convert, linearProgToLPE, rowEchelonLPE) import Futhark.Solve.Matrix @@ -30,35 +29,35 @@ import Futhark.Solve.Matrix -- equal to @z'@. -- | Computes @r@ as given in the tableau above. -comp_r :: +compR :: (Num a, Unbox a) => LPE a -> Matrix a -> Vector Int -> Vector Int -> Vector a -comp_r (LPE c a _) invA_B b n = +compR (LPE c a _) invA_B b n = c @ n .-. c @ b .* invA_B .* a @ n --- | @comp_q_enter prob invA_B b n enter@ computes the @enter@th +-- | @compQEnter prob invA_B b n enter@ computes the @enter@th -- column of @q@. -comp_q_enter :: +compQEnter :: (Num a, Unbox a) => LPE a -> Matrix a -> Int -> Vector a -comp_q_enter (LPE _ a _) invA_B enter = +compQEnter (LPE _ a _) invA_B enter = V.map negate $ invA_B *. getCol enter a -- | Computes the objective given an inversion of @a@ and a basis. -comp_z :: +compZ :: (Num a, Unbox a) => LPE a -> Matrix a -> Vector Int -> a -comp_z (LPE c _ d) invA_B b = +compZ (LPE c _ d) invA_B b = c @ b .* invA_B <.> d -- | Constructs an auxiliary equational linear program to compute the @@ -75,6 +74,57 @@ mkAux (LPE _ a d) = (LPE c_aux a_aux d_aux, b_aux, n_aux) b_aux = V.generate (nrows a) (+ ncols a) n_aux = V.generate (ncols a) id +fixDegenerateBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Int -> + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +fixDegenerateBasis og_prob col prob (invA_B, p, b, n) + | Just exit_idx <- mexit_idx, + V.null (elim_row exit_idx) = + let prob' = + prob + { pA = deleteRow exit_idx (pA prob), + pd = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) $ + pd prob + } + invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B + p' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) p + b' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) b + in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) + | Just exit_idx <- mexit_idx, + (enter, _) <- V.head (elim_row exit_idx) = + let enter_idx = fromJust $ V.findIndex (== enter) n + exit = b V.! exit_idx + in fixDegenerateBasis og_prob col prob $ + pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = + let prob' = + prob + { pc = pc og_prob, + pA = sliceCols (V.generate col id) $ pA prob, + pd = V.map abs $ pd og_prob + } + in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) + where + mexit_idx = + fst <$> V.filter ((>= col) . snd) (V.imap (curry id) b) V.!? 0 + elim_row exit_idx = + V.filter ((/= 0) . snd) $ + V.map (\j -> (j, compQEnter prob invA_B j V.! exit_idx)) $ + V.generate col id + -- | Finds an initial feasible basis for an equational linear program. -- Returns 'Nothing' if the LP has no solution. Inverts some -- equations by multiplying by -1 so it also returns a modified (but @@ -85,64 +135,13 @@ findBasis :: Maybe (LPE a, Matrix a, Vector a, Vector Int, Vector Int) findBasis prob = do (invA_B, p, b, n) <- step p_aux (invA_B_aux, d_aux, b_aux, n_aux) - if comp_z p_aux invA_B b == 0 + if compZ p_aux invA_B b == 0 then Just $ fixDegenerateBasis prob (ncols $ pA prob) p_aux (invA_B, p, b, n) else Nothing where (p_aux@(LPE _ _ d_aux), b_aux, n_aux) = mkAux prob invA_B_aux = identity $ V.length b_aux - fixDegenerateBasis :: - (Unbox a, Ord a, Fractional a, Show a) => - LPE a -> - Int -> - LPE a -> - (Matrix a, Vector a, Vector Int, Vector Int) -> - (LPE a, Matrix a, Vector a, Vector Int, Vector Int) - fixDegenerateBasis og_prob col prob (invA_B, p, b, n) - | Just exit_idx <- mexit_idx, - V.null (elim_row exit_idx) = - let prob' = - prob - { pA = deleteRow exit_idx (pA prob), - pd = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) $ - pd prob - } - invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B - p' = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) p - b' = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) b - in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) - | Just exit_idx <- mexit_idx, - (enter, _) <- V.head (elim_row exit_idx) = - let enter_idx = fromJust $ V.findIndex (== enter) n - exit = b V.! exit_idx - in fixDegenerateBasis og_prob col prob $ - pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) - | otherwise = - let prob' = - prob - { pc = pc og_prob, - pA = sliceCols (V.generate col id) $ pA prob, - pd = V.map abs $ pd og_prob - } - in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) - where - mexit_idx = - fst <$> ((V.filter ((>= col) . snd) (V.imap (curry id) b)) V.!? 0) - elim_row exit_idx = - V.filter ((/= 0) . snd) $ - V.map (\j -> (j, comp_q_enter prob invA_B j V.! exit_idx)) $ - V.generate col id - -- | Solves an equational linear program. Returns 'Nothing' if the -- program is infeasible or unbounded. Otherwise returns the optimal -- value and the solution. @@ -151,10 +150,9 @@ simplex :: LPE a -> Maybe (a, Vector a) simplex lpe = do - let ech_lpe = rowEchelonLPE lpe - res@(lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe + (lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe (invA_B', p', b', n') <- step lpe' (invA_B, p, b, n) - let z = comp_z lpe' invA_B' b' + let z = compZ lpe' invA_B' b' sol = V.map snd $ V.fromList $ @@ -180,12 +178,12 @@ simplexProg :: Maybe (a, Map v a) simplexProg prog = do (z, sol) <- simplex lpe - pure $ (z, M.fromList $ map (\(i, x) -> (idxMap M.! i, x)) $ zip [0 ..] $ V.toList sol) + pure (z, M.fromList $ zipWith (\i x -> (idxMap M.! i, x)) [0 ..] $ V.toList sol) where (lpe, idxMap) = linearProgToLPE prog pivot :: - (Unbox a, Ord a, Fractional a, Show a) => + (Unbox a, Fractional a) => LPE a -> (Matrix a, Vector a, Vector Int, Vector Int) -> (Int, Int) -> @@ -194,7 +192,7 @@ pivot :: pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) = (invA_B', p', b', n') where - q_enter = comp_q_enter prob invA_B enter + q_enter = compQEnter prob invA_B enter b' = b V.// [(exit_idx, enter)] n' = n V.// [(enter_idx, exit)] e_inv_vec = @@ -216,7 +214,7 @@ step :: step prob (invA_B, p, b, n) | Just enter_idx <- menter_idx = let enter = n V.! enter_idx - q_enter = comp_q_enter prob invA_B enter + q_enter = compQEnter prob invA_B enter pq = V.map (\(i, p_', q_) -> (i, -(p_' / q_))) $ V.filter (\(_, _, q_) -> q_ < 0) $ @@ -233,6 +231,5 @@ step prob (invA_B, p, b, n) in step prob $ pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) | otherwise = Just (invA_B, p, b, n) where - r = comp_r prob invA_B b n + r = compR prob invA_B b n menter_idx = V.findIndex (> 0) r - b_zero = V.filter (\(v, i) -> v == 0 && (not $ V.null (V.filter (< i) n))) $ V.zip p b diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index e6759a9ff7..2bee66d8a6 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -213,7 +213,7 @@ unifySharedFields reason fs1 fs2 = solveEq reason ts1 ts2 mustSupportEql :: Reason -> Type -> SolveM () -mustSupportEql reason t = pure () +mustSupportEql _reason _t = pure () -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 75dd2947ea..73a7315356 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1046,7 +1046,7 @@ checkOneExp e = do (maybe_tysubsts, e') <- Terms2.checkSingleExp e case maybe_tysubsts of Left err -> throwError err - Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' let t = typeOf e'' (tparams, _, _) <- @@ -1064,7 +1064,7 @@ checkSizeExp e = do (maybe_tysubsts, e') <- Terms2.checkSizeExp e case maybe_tysubsts of Left err -> throwError err - Right (generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do + Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do e'' <- checkExp e' when (hasBinding e'') $ typeError (srclocOf e'') mempty . withIndexLink "size-expression-bind" $ diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index b8371cbfe8..fc1733df5e 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -60,7 +60,6 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified -import Futhark.Util import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals @@ -533,8 +532,6 @@ instance MonadTypeChecker TermTypeM where lookupVar :: SrcLoc -> QualName VName -> StructType -> TermTypeM StructType lookupVar loc qn@(QualName qs name) inst_t = do scope <- lookupQualNameEnv qn - let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) - case M.lookup name $ scopeVtable scope of Nothing -> error $ "lookupVar: " <> show qn diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cfa88069ca..4a65aedf91 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -249,7 +249,7 @@ newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniquen newTypeOverloaded loc name pts = tyVarType NoUniqueness <$> newTyVarWith name (TyVarPrim (locOf loc) pts) -newSVar :: (Located loc) => loc -> Name -> TermM SVar +newSVar :: loc -> Name -> TermM SVar newSVar _loc desc = do i <- incCounter newID $ mkTypeVarName desc i @@ -426,7 +426,7 @@ lookupVar loc qn@(QualName qs name) = do if null tparams && null qs then pure t else do - (tnames, t') <- instTypeScheme qn loc tparams t + (_tnames, t') <- instTypeScheme qn loc tparams t -- TODO - qualify type names, like in the old type checker. pure t' Just EqualityF -> do @@ -707,22 +707,12 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} ) where - -- stripFrame :: Shape Size -> Type -> Type - -- stripFrame frame (Array u ds t) = - -- let mnew_shape = Shape <$> L.stripPrefix (toSComp <$> shapeDims frame) (shapeDims ds) - -- in case mnew_shape of - -- Nothing -> Scalar t - -- Just new_shape -> arrayOfWithAliases u new_shape $ Scalar t - -- stripFrame _ t = t - - isFunType (Scalar Arrow {}) = True - isFunType _ = False -- (fix) toSComp (Var (QualName [] x) _ _) = SVar x toSComp _ = error "" toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) - split (Array u s t) = do + split (Array _u s t) = do (a, b) <- split $ Scalar t pure (arrayOf s a, arrayOf s b) split ftype' = do @@ -1373,7 +1363,7 @@ checkSingleExp e = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars artificial <- gets termArtificial - ((cts', artificial', tyvars'), _, e'') <- + ((cts', _artificial', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' case solve cts' tyvars' of Left err -> pure (Left err, e'') @@ -1396,7 +1386,7 @@ checkSizeExp e = runTermM $ do (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- - forM cts_tyvars' $ \(cts', artificial', tyvars') -> + forM cts_tyvars' $ \(cts', _artificial', tyvars') -> bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' case (solutions, es') of From 4cbdb8de325b126666b2c0197078742f85051412 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 16:05:31 +0200 Subject: [PATCH 214/296] Warning-free tests. --- .../Futhark/Solve/BranchAndBoundTests.hs | 57 +++++++++---------- unittests/Futhark/Solve/SimplexTests.hs | 48 ++++++---------- 2 files changed, 45 insertions(+), 60 deletions(-) diff --git a/unittests/Futhark/Solve/BranchAndBoundTests.hs b/unittests/Futhark/Solve/BranchAndBoundTests.hs index ed7e04c715..b7e1bfe027 100644 --- a/unittests/Futhark/Solve/BranchAndBoundTests.hs +++ b/unittests/Futhark/Solve/BranchAndBoundTests.hs @@ -1,3 +1,5 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + module Futhark.Solve.BranchAndBoundTests ( tests, ) @@ -10,7 +12,6 @@ import Futhark.Solve.Matrix qualified as M import Test.Tasty import Test.Tasty.HUnit import Prelude hiding (or) -import Prelude qualified tests :: TestTree tests = @@ -68,31 +69,28 @@ tests = case branchAndBound lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (11.8 :: Double), - and $ zipWith (==) (V.toList sol) [1, 3] - ], - testCase "5" $ - let prog = - LinearProg - { optType = Maximize, - objective = var "x1" ~+~ var "x2", - constraints = - [ var "x1" ~<=~ constant 10, - var "x2" ~<=~ constant 5 - ] - <> oneIsZero ("b1", "x1") ("b2", "x2") - } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp - in assertBool - (unlines [show $ branchAndBound lp]) - $ case branchAndBound lp of - Nothing -> False - Just (z, sol) -> - and - [ z `approxEq` (10 :: Double) - ], + (z `approxEq` (11.8 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 3]), + -- testCase "5" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> oneIsZero ("b1", "x1") ("b2", "x2") + -- } + -- (lp, _idxmap) = linearProgToLP prog + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, _sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ], -- testCase "6" $ -- let prog = -- LinearProg @@ -130,17 +128,16 @@ tests = var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ branchAndBound lp]) $ case branchAndBound lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (0 :: Double) ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool -approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/unittests/Futhark/Solve/SimplexTests.hs b/unittests/Futhark/Solve/SimplexTests.hs index 1a52203d12..c29bd10a93 100644 --- a/unittests/Futhark/Solve/SimplexTests.hs +++ b/unittests/Futhark/Solve/SimplexTests.hs @@ -1,17 +1,17 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + module Futhark.Solve.SimplexTests ( tests, ) where import Data.Vector.Unboxed qualified as V -import Debug.Trace import Futhark.Solve.LP import Futhark.Solve.Matrix qualified as M import Futhark.Solve.Simplex import Test.Tasty import Test.Tasty.HUnit import Prelude hiding (or) -import Prelude qualified tests :: TestTree tests = @@ -69,10 +69,8 @@ tests = case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (14.08 :: Double), - and $ zipWith approxEq (V.toList sol) [1.3, 3.3] - ], + (z `approxEq` (14.08 :: Double)) + && and (zipWith approxEq (V.toList sol) [1.3, 3.3]), testCase "5" $ let lp = LP @@ -88,10 +86,8 @@ tests = case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (0 :: Double), - and $ zipWith approxEq (V.toList sol) [0] - ], + (z `approxEq` (0 :: Double)) + && and (zipWith approxEq (V.toList sol) [0]), testCase "6" $ let lp = LP @@ -107,10 +103,8 @@ tests = case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (5 :: Double), - and $ zipWith approxEq (V.toList sol) [5] - ], + z `approxEq` (5 :: Double) + && and (zipWith approxEq (V.toList sol) [5]), testCase "7" $ let prog = LinearProg @@ -121,17 +115,14 @@ tests = var "b1" ~+~ var "b2" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False Just (z, sol) -> - and - [ z `approxEq` (10 :: Double), - and $ zipWith (==) (V.toList sol) [1, 0, 10] - ], + (z `approxEq` (10 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 0, 10]), testCase "8" $ let prog = LinearProg @@ -143,13 +134,12 @@ tests = ] <> oneIsZero ("b1", "x1") ("b2", "x2") } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (15 :: Double) ], @@ -192,13 +182,12 @@ tests = var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (0 :: Double) ], @@ -217,17 +206,16 @@ tests = var "0b_R" ~+~ var "1b_M" ~<=~ constant 1 ] } - (lp, idxmap) = linearProgToLP prog - lpe = convert lp + (lp, _idxmap) = linearProgToLP prog in assertBool (unlines [show $ simplexLP lp]) $ case simplexLP lp of Nothing -> False - Just (z, sol) -> + Just (z, _sol) -> and [ z `approxEq` (0 :: Double) ] ] approxEq :: (Fractional a, Ord a) => a -> a -> Bool -approxEq x1 x2 = (abs $ x1 - x2) < 10 ^^ (-10 :: Int) +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) From a4946a133b80b015cc6ed0f7e8df40e977d44e89 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 16:11:16 +0200 Subject: [PATCH 215/296] Link against static glpk. --- default.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/default.nix b/default.nix index 17321c27d6..b3f82aac07 100644 --- a/default.nix +++ b/default.nix @@ -75,6 +75,7 @@ let "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.zlib.static}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" + "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { configureFlags = ["--enable-static"] ++ old.configureFlags;})}/lib" ]; preBuild = '' From 9cf8ab8b68ee1e68e8f938a0055507ff47ce9848 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 May 2024 16:20:04 +0200 Subject: [PATCH 216/296] This is cleaner. --- default.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/default.nix b/default.nix index b3f82aac07..46d1fdcf3c 100644 --- a/default.nix +++ b/default.nix @@ -75,7 +75,7 @@ let "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.zlib.static}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" - "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { configureFlags = ["--enable-static"] ++ old.configureFlags;})}/lib" + "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { dontDisableStatic = true; })}/lib" ]; preBuild = '' From 323922fdf50554daa7565b94224c31872b683e43 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 6 Jun 2024 18:28:34 +0200 Subject: [PATCH 217/296] Fix typo. --- default.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/default.nix b/default.nix index 026f95049e..4d27d336c9 100644 --- a/default.nix +++ b/default.nix @@ -38,7 +38,7 @@ let haskellPackagesNew.callPackage ./nix/zlib.nix {zlib=pkgs.zlib;}; gasp = - haskellPackagesNew.callPackage ./nix/.nix {}; + haskellPackagesNew.callPackage ./nix/gasp.nix {}; glpk-hs = haskellPackagesNew.callPackage ./nix/glpk-hs.nix {}; From 8a808e037afcbf61a221d96220ea584bdafcbaa3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 8 Jun 2024 20:15:44 +0200 Subject: [PATCH 218/296] let should not be generalised. --- tests/types/inference5.fut | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 tests/types/inference5.fut diff --git a/tests/types/inference5.fut b/tests/types/inference5.fut deleted file mode 100644 index 900704f21a..0000000000 --- a/tests/types/inference5.fut +++ /dev/null @@ -1,7 +0,0 @@ --- Inference for a local function. --- == --- input { 2 } output { 4 } - -def main x = - let apply f x = f x - in apply (apply (i32.+) x) x From e71f281db0622fd3f82fe1f9183282cc0c3a5258 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 8 Jun 2024 22:33:19 +0200 Subject: [PATCH 219/296] Minor refactoring. --- .../Futhark/TypeChecker/Constraints.hs | 150 ++++++++++-------- 1 file changed, 88 insertions(+), 62 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 240f8ca765..f6a8e2be8e 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -115,30 +115,39 @@ instance Located TyVarInfo where type TyVar = VName --- | If a VName is not in this map, it is assumed to be rigid. The --- integer is the level. -type TyVars = M.Map TyVar (Int, TyVarInfo) +-- | The level at which a type variable is bound. Higher means +-- deeper. We can only unify a type variable at level @i@ with a type +-- @t@ if all type names that occur in @t@ are at most at level @i@. +type Level = Int + +-- | If a VName is not in this map, it is assumed to be rigid. +type TyVars = M.Map TyVar (Level, TyVarInfo) data TyVarSol = -- | Has been substituted with this. - TyVarSol Int Type - | -- | Replaced by this other type variable. - TyVarLink VName + TyVarSol Level Type | -- | Not substituted yet; has this constraint. - TyVarUnsol Int TyVarInfo + TyVarUnsol Level TyVarInfo deriving (Show) -newtype SolverState = SolverState {solverTyVars :: M.Map TyVar TyVarSol} +tyVarSolLevel :: TyVarSol -> Level +tyVarSolLevel (TyVarSol lvl _) = lvl +tyVarSolLevel (TyVarUnsol lvl _) = lvl + +newtype SolverState = SolverState + { -- | Left means linked to this other type variable. + solverTyVars :: M.Map TyVar (Either VName TyVarSol) + } initialState :: TyVars -> SolverState -initialState tyvars = SolverState $ M.map (uncurry TyVarUnsol) tyvars +initialState tyvars = SolverState $ M.map (Right . uncurry TyVarUnsol) tyvars -substTyVar :: (Monoid u) => M.Map TyVar TyVarSol -> VName -> Maybe (TypeBase SComp u) +substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = case M.lookup v m of - Just (TyVarLink v') -> substTyVar m v' - Just (TyVarSol _ t') -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' - Just (TyVarUnsol {}) -> Nothing + Just (Left v') -> substTyVar m v' + Just (Right (TyVarSol _ t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing lookupTyVar :: TyVar -> SolveM (Int, Either TyVarInfo Type) @@ -146,9 +155,9 @@ lookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of Nothing -> error $ "Unknown tyvar: " <> prettyNameString v - Just (TyVarSol lvl t) -> pure (lvl, Right t) - Just (TyVarLink v') -> f v' - Just (TyVarUnsol lvl info) -> pure (lvl, Left info) + Just (Left v') -> f v' + Just (Right (TyVarSol lvl t)) -> pure (lvl, Right t) + Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig -- | A solution maps a type variable to its substitution. This @@ -162,15 +171,15 @@ solution s = M.mapMaybe mkSubst $ solverTyVars s ) where - mkSubst (TyVarSol _lvl t) = + mkSubst (Right (TyVarSol _lvl t)) = Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t - mkSubst (TyVarLink v') = + mkSubst (Left v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (TyVarUnsol _ (TyVarPrim _ pts)) = Just $ Left pts + mkSubst (Right (TyVarUnsol _ (TyVarPrim _ pts))) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, TyVarUnsol _ (TyVarFree _)) = Just v + unconstrained (v, Right (TyVarUnsol _ (TyVarFree _))) = Just v unconstrained _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} @@ -216,15 +225,27 @@ unifySharedFields reason fs1 fs2 = mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () +scopeViolation :: Reason -> VName -> Type -> SolveM a +scopeViolation reason v tp = + throwError . TypeError (locOf reason) mempty $ + "Cannot unify type" + indent 2 (pretty tp) + "with" + <+> dquotes (prettyName v) + <+> "(scope violation)." + "This is because" + <+> dquotes (prettyName v) + <+> "is rigidly bound in a deeper scope." + -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars case (v_info, t) of - (Just (TyVarUnsol _ TyVarFree {}), _) -> + (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( Just (Right (TyVarUnsol _ (TyVarPrim _ v_pts))), _ ) -> if t `elem` map (Scalar . Prim) v_pts @@ -235,7 +256,7 @@ subTyVar reason v lvl t = do indent 2 (pretty v_pts) "with" indent 2 (pretty t) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), Scalar (Sum cs2) ) -> if all (`elem` M.keys cs2) (M.keys cs1) @@ -246,7 +267,7 @@ subTyVar reason v lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Sum cs2)) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), _ ) -> throwError . TypeError (locOf reason) mempty $ @@ -254,7 +275,7 @@ subTyVar reason v lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty t) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), Scalar (Record fs2) ) -> if all (`elem` M.keys fs2) (M.keys fs1) @@ -265,7 +286,7 @@ subTyVar reason v lvl t = do indent 2 (pretty (Record fs1)) "with record type" indent 2 (pretty (Record fs2)) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), _ ) -> throwError . TypeError (locOf reason) mempty $ @@ -273,43 +294,48 @@ subTyVar reason v lvl t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty t) - (Just (TyVarUnsol _ (TyVarEql _)), _) -> + (Just (Right (TyVarUnsol _ (TyVarEql _))), _) -> mustSupportEql reason t -- -- Internal error cases - (Just TyVarSol {}, _) -> + (Just (Right TyVarSol {}), _) -> error $ "Type variable already solved: " <> prettyNameString v - (Just TyVarLink {}, _) -> + (Just Left {}, _) -> error $ "Type variable already linked: " <> prettyNameString v (Nothing, _) -> error $ "linkTyVar: Nothing v: " <> prettyNameString v setInfo v (TyVarSol lvl t) +setLink :: TyVar -> VName -> SolveM () +setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} + setInfo :: TyVar -> TyVarSol -> SolveM () -setInfo v info = modify $ \s -> s {solverTyVars = M.insert v info $ solverTyVars s} +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} -- Precondition: 'v' is currently flexible and 't' has no solution. linkTyVar :: Reason -> VName -> VName -> SolveM () linkTyVar reason v t = do - occursCheck reason v $ Scalar $ TypeVar NoUniqueness (qualName t) [] - v_info <- gets $ M.lookup v . solverTyVars + v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars (lvl, t') <- lookupTyVar t + let tp = Scalar $ TypeVar NoUniqueness (qualName t) [] + occursCheck reason v tp + case (v_info, t') of -- When either is completely unconstrained. - (Just (TyVarUnsol _ TyVarFree {}), _) -> + (TyVarUnsol _ TyVarFree {}, _) -> pure () - ( Just (TyVarUnsol _ info), + ( TyVarUnsol _ info, Left (TyVarFree {}) ) -> setInfo t (TyVarUnsol lvl info) -- -- TyVarPrim cases - ( Just (TyVarUnsol _ info@TyVarPrim {}), + ( TyVarUnsol _ info@TyVarPrim {}, Left TyVarEql {} ) -> setInfo t (TyVarUnsol lvl info) - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( TyVarUnsol _ (TyVarPrim _ v_pts), Left (TyVarPrim t_loc t_pts) ) -> let pts = L.intersect v_pts t_pts @@ -321,14 +347,14 @@ linkTyVar reason v t = do "with type that must be one of" indent 2 (pretty t_pts) else setInfo t (TyVarUnsol lvl (TyVarPrim t_loc pts)) - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarRecord {} ) -> throwError . TypeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." - ( Just (TyVarUnsol _ (TyVarPrim _ v_pts)), + ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarSum {} ) -> throwError . TypeError (locOf reason) mempty $ @@ -337,19 +363,19 @@ linkTyVar reason v t = do "with type that must be sum." -- -- TyVarSum cases - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarSum loc cs2) ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 setInfo t (TyVarUnsol lvl (TyVarSum loc cs3)) - ( Just (TyVarUnsol _ TyVarSum {}), + ( TyVarUnsol _ TyVarSum {}, Left (TyVarPrim _ pts) ) -> throwError . TypeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarRecord _ fs) ) -> throwError . TypeError (locOf reason) mempty $ @@ -357,25 +383,25 @@ linkTyVar reason v t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Scalar (Record fs))) - ( Just (TyVarUnsol _ (TyVarSum _ cs1)), + ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarEql _) ) -> mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarRecord loc fs2) ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 setInfo t (TyVarUnsol lvl (TyVarRecord loc fs3)) - ( Just (TyVarUnsol _ TyVarRecord {}), + ( TyVarUnsol _ TyVarRecord {}, Left (TyVarPrim _ pts) ) -> throwError . TypeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarSum _ cs) ) -> throwError . TypeError (locOf reason) mempty $ @@ -383,33 +409,33 @@ linkTyVar reason v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty (Scalar (Sum cs))) - ( Just (TyVarUnsol _ (TyVarRecord _ fs1)), + ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarEql _) ) -> mapM_ (mustSupportEql reason) fs1 -- -- TyVarEql cases - (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarPrim {}) -> + (TyVarUnsol _ (TyVarEql _), Left TyVarPrim {}) -> pure () - (Just (TyVarUnsol _ (TyVarEql _)), Left TyVarEql {}) -> + (TyVarUnsol _ (TyVarEql _), Left TyVarEql {}) -> pure () - (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarRecord _ fs)) -> + (TyVarUnsol _ (TyVarEql _), Left (TyVarRecord _ fs)) -> mustSupportEql reason $ Scalar $ Record fs - (Just (TyVarUnsol _ (TyVarEql _)), Left (TyVarSum _ cs)) -> + (TyVarUnsol _ (TyVarEql _), Left (TyVarSum _ cs)) -> mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases - (Just TyVarSol {}, _) -> - error $ "Type variable already solved: " <> prettyNameString v - (Just TyVarLink {}, _) -> - error $ "Type variable already linked: " <> prettyNameString v - (Nothing, _) -> - error $ "linkTyVar: Nothing v: " <> prettyNameString v + (TyVarSol {}, _) -> + alreadySolved (_, Right t'') -> error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' -- Finally insert the actual link. - setInfo v (TyVarLink t) + setLink v t + where + unknown = error $ "linkTyVar: Nothing v: " <> prettyNameString v + alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v + alreadySolved = error $ "Type variable already solved: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -459,14 +485,14 @@ solveEq reason orig_t1 orig_t2 = do solveCt' (t1, t2) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of - Just (TyVarLink v') -> flexible v' - Just (TyVarUnsol lvl _) -> Just lvl - Just (TyVarSol _ _) -> Nothing + Just (Left v') -> flexible v' + Just (Right (TyVarUnsol lvl _)) -> Just lvl + Just (Right (TyVarSol _ _)) -> Nothing Nothing -> Nothing sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of - Just (TyVarLink v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (TyVarSol _ t') -> sub t' + Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) + Just (Right (TyVarSol _ t')) -> sub t' _ -> t sub t = t case (sub t1, sub t2) of From 0d7e6614f6699b6cce0b9fcb791749bbe16b657a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 3 Jul 2024 10:05:49 +0200 Subject: [PATCH 220/296] Crudely strangle warnings. --- src/Language/Futhark/TypeChecker/Constraints.hs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index f6a8e2be8e..24521c0327 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -11,6 +11,9 @@ module Language.Futhark.TypeChecker.Constraints TyVars, Solution, solve, + -- To hide warnings + tyVarSolLevel, + scopeViolation, ) where From f126abea1519a017d0da1ecf83a7944bcbb8119f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 5 Jul 2024 16:18:40 +0200 Subject: [PATCH 221/296] Track explicit type parameters in constraint solver. I cannot figure out whether this is hacky or OK, but we need it to handle level checks correctly. --- .../Futhark/TypeChecker/Constraints.hs | 39 ++++++++++++++----- src/Language/Futhark/TypeChecker/Terms2.hs | 32 +++++++++++---- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 24521c0327..3eb1f9d010 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -9,6 +9,7 @@ module Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVar, TyVars, + TyParams, Solution, solve, -- To hide warnings @@ -86,8 +87,8 @@ instance Pretty Ct where type Constraints = [Ct] --- | Information about a type variable. Every type variable is --- associated with a location, which is the original syntax element +-- | Information about a flexible type variable. Every type variable +-- is associated with a location, which is the original syntax element -- that it is the type of. data TyVarInfo = -- | Can be substituted with anything. @@ -123,18 +124,26 @@ type TyVar = VName -- @t@ if all type names that occur in @t@ are at most at level @i@. type Level = Int --- | If a VName is not in this map, it is assumed to be rigid. +-- | If a VName is not in this map, it should be in the 'TyParams' - +-- the exception is abstract types, which are just missing (and +-- assumed to have smallest possible level). type TyVars = M.Map TyVar (Level, TyVarInfo) +-- | Explicit type parameters. +type TyParams = M.Map TyVar (Level, Loc) + data TyVarSol = -- | Has been substituted with this. TyVarSol Level Type + | -- | Is an explicit type parameter in the source program. + TyVarParam Level Loc | -- | Not substituted yet; has this constraint. TyVarUnsol Level TyVarInfo deriving (Show) tyVarSolLevel :: TyVarSol -> Level tyVarSolLevel (TyVarSol lvl _) = lvl +tyVarSolLevel (TyVarParam lvl _) = lvl tyVarSolLevel (TyVarUnsol lvl _) = lvl newtype SolverState = SolverState @@ -142,14 +151,18 @@ newtype SolverState = SolverState solverTyVars :: M.Map TyVar (Either VName TyVarSol) } -initialState :: TyVars -> SolverState -initialState tyvars = SolverState $ M.map (Right . uncurry TyVarUnsol) tyvars +initialState :: TyParams -> TyVars -> SolverState +initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars + where + f (lvl, info) = Right $ TyVarUnsol lvl info + g (lvl, loc) = Right $ TyVarParam lvl loc substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = case M.lookup v m of Just (Left v') -> substTyVar m v' Just (Right (TyVarSol _ t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right TyVarParam {}) -> Nothing Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing @@ -160,6 +173,8 @@ lookupTyVar orig = do Nothing -> error $ "Unknown tyvar: " <> prettyNameString v Just (Left v') -> f v' Just (Right (TyVarSol lvl t)) -> pure (lvl, Right t) + Just (Right (TyVarParam lvl _)) -> + pure (lvl, Right $ Scalar $ TypeVar mempty (qualName orig) []) Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig @@ -303,6 +318,8 @@ subTyVar reason v lvl t = do -- Internal error cases (Just (Right TyVarSol {}), _) -> error $ "Type variable already solved: " <> prettyNameString v + (Just (Right TyVarParam {}), _) -> + error $ "Cannot substitute type parameter: " <> prettyNameString v (Just Left {}, _) -> error $ "Type variable already linked: " <> prettyNameString v (Nothing, _) -> @@ -430,6 +447,8 @@ linkTyVar reason v t = do -- Internal error cases (TyVarSol {}, _) -> alreadySolved + (TyVarParam {}, _) -> + isParam (_, Right t'') -> error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' @@ -439,6 +458,7 @@ linkTyVar reason v t = do unknown = error $ "linkTyVar: Nothing v: " <> prettyNameString v alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v alreadySolved = error $ "Type variable already solved: " <> prettyNameString v + isParam = error $ "Type name is a type parameter: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Maybe [(Type, Type)] @@ -490,7 +510,8 @@ solveEq reason orig_t1 orig_t2 = do let flexible v = case M.lookup v tyvars of Just (Left v') -> flexible v' Just (Right (TyVarUnsol lvl _)) -> Just lvl - Just (Right (TyVarSol _ _)) -> Nothing + Just (Right TyVarSol {}) -> Nothing + Just (Right TyVarParam {}) -> Nothing Nothing -> Nothing sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of @@ -552,11 +573,11 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do solveTyVar (_, _) = pure () -solve :: Constraints -> TyVars -> Either TypeError ([VName], Solution) -solve constraints tyvars = +solve :: Constraints -> TyParams -> TyVars -> Either TypeError ([VName], Solution) +solve constraints typarams tyvars = second solution . runExcept - . flip execStateT (initialState tyvars) + . flip execStateT (initialState typarams tyvars) . runSolveM $ do mapM_ solveCt constraints diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 56f36c7b9e..cb3082908a 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -112,6 +112,7 @@ data TermEnv = TermEnv data TermState = TermState { termConstraints :: Constraints, termTyVars :: TyVars, + termTyParams :: TyParams, termCounter :: !Int, termWarnings :: Warnings, termNameSource :: VNameSource, @@ -192,6 +193,7 @@ runTermM (TermM m) = do TermState { termConstraints = mempty, termTyVars = mempty, + termTyParams = mempty, termWarnings = mempty, termNameSource = src, termCounter = 0, @@ -632,14 +634,25 @@ bindTypes tbinds = localScope extend } bindTypeParams :: [TypeParam] -> TermM a -> TermM a -bindTypeParams tparams = - bind (mapMaybe typeParamIdent tparams) - . bindTypes (mapMaybe typeParamType tparams) +bindTypeParams tparams m = + bind idents . bindTypes types $ do + lvl <- curLevel + modify $ \s -> + s + { termTyParams = + termTyParams s + <> M.fromList (mapMaybe (typeParam lvl) tparams) + } + m where + idents = mapMaybe typeParamIdent tparams + types = mapMaybe typeParamType tparams typeParamType (TypeParamType l v _) = Just (v, TypeAbbr l [] $ RetType [] $ Scalar (TypeVar mempty (qualName v) [])) typeParamType TypeParamDim {} = Nothing + typeParam lvl (TypeParamType _ v loc) = Just (v, (lvl, locOf loc)) + typeParam _ _ = Nothing bindParams :: [TypeParam] -> @@ -1311,6 +1324,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do cts <- gets termConstraints tyvars <- gets termTyVars + typarams <- gets termTyParams artificial <- gets termArtificial debugTraceM 3 $ "\n# function " <> prettyNameString fname <> "\n# " <> locStr loc <> "\n" @@ -1327,15 +1341,15 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (\(v, t) -> prettyNameString v <> " => " <> prettyString t) (M.toList artificial) ] - onRankSolution retdecl' + onRankSolution retdecl' typarams =<< rankAnalysis1 loc cts tyvars artificial params' body' where - onRankSolution retdecl' ((cts', artificial, tyvars'), params', body'') = do + onRankSolution retdecl' typarams ((cts', artificial, tyvars'), params', body'') = do solution <- bitraverse pure (fmap (second (onArtificial artificial)) . onTySolution params' body'') - $ solve cts' tyvars' + $ solve cts' typarams tyvars' debugTraceM 3 $ unlines [ "## constraints:", @@ -1367,10 +1381,11 @@ checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints tyvars <- gets termTyVars + typarams <- gets termTyParams artificial <- gets termArtificial ((cts', _artificial', tyvars'), _, e'') <- rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' - case solve cts' tyvars' of + case solve cts' typarams tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do e_t <- expType e'' @@ -1386,13 +1401,14 @@ checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints tyvars <- gets termTyVars + typarams <- gets termTyParams artificial <- gets termArtificial (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' solutions <- forM cts_tyvars' $ \(cts', _artificial', tyvars') -> - bitraverse pure (traverse (doDefaults mempty)) $ solve cts' tyvars' + bitraverse pure (traverse (doDefaults mempty)) $ solve cts' typarams tyvars' case (solutions, es') of ([solution], [e'']) -> From 1ace2df74e8032d8b712b07ade141d71bdd7df1c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 14:44:49 +0200 Subject: [PATCH 222/296] A bit more work. --- .../Futhark/TypeChecker/Constraints.hs | 69 ++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 3eb1f9d010..af5a5f2cfd 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -12,9 +12,6 @@ module Language.Futhark.TypeChecker.Constraints TyParams, Solution, solve, - -- To hide warnings - tyVarSolLevel, - scopeViolation, ) where @@ -135,17 +132,12 @@ type TyParams = M.Map TyVar (Level, Loc) data TyVarSol = -- | Has been substituted with this. TyVarSol Level Type - | -- | Is an explicit type parameter in the source program. + | -- | Is an explicit (rigid) type parameter in the source program. TyVarParam Level Loc | -- | Not substituted yet; has this constraint. TyVarUnsol Level TyVarInfo deriving (Show) -tyVarSolLevel :: TyVarSol -> Level -tyVarSolLevel (TyVarSol lvl _) = lvl -tyVarSolLevel (TyVarParam lvl _) = lvl -tyVarSolLevel (TyVarUnsol lvl _) = lvl - newtype SolverState = SolverState { -- | Left means linked to this other type variable. solverTyVars :: M.Map TyVar (Either VName TyVarSol) @@ -178,6 +170,12 @@ lookupTyVar orig = do Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig +setLink :: TyVar -> VName -> SolveM () +setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} + +setInfo :: TyVar -> TyVarSol -> SolveM () +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} + -- | A solution maps a type variable to its substitution. This -- substitution is complete, in the sense there are no right-hand -- sides that contain a type variable. @@ -243,23 +241,38 @@ unifySharedFields reason fs1 fs2 = mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () -scopeViolation :: Reason -> VName -> Type -> SolveM a -scopeViolation reason v tp = +scopeViolation :: Reason -> VName -> Type -> VName -> SolveM a +scopeViolation reason v1 ty v2 = throwError . TypeError (locOf reason) mempty $ "Cannot unify type" - indent 2 (pretty tp) + indent 2 (pretty ty) "with" - <+> dquotes (prettyName v) + <+> dquotes (prettyName v1) <+> "(scope violation)." "This is because" - <+> dquotes (prettyName v) + <+> dquotes (prettyName v2) <+> "is rigidly bound in a deeper scope." +scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () +scopeCheck reason v v_lvl ty = do + mapM_ check $ typeVars ty + where + check ty_v = do + ty_v_info <- gets $ M.lookup ty_v . solverTyVars + case ty_v_info of + Just (Right (TyVarParam ty_v_lvl _)) + | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + Just (Right (TyVarUnsol ty_v_lvl info)) + | ty_v_lvl /= v_lvl -> + setInfo ty_v $ TyVarUnsol v_lvl info + _ -> pure () + -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () -subTyVar reason v lvl t = do +subTyVar reason v v_lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars + scopeCheck reason v v_lvl t case (v_info, t) of (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () @@ -323,23 +336,15 @@ subTyVar reason v lvl t = do (Just Left {}, _) -> error $ "Type variable already linked: " <> prettyNameString v (Nothing, _) -> - error $ "linkTyVar: Nothing v: " <> prettyNameString v - - setInfo v (TyVarSol lvl t) + error $ "subTyVar: Nothing v: " <> prettyNameString v -setLink :: TyVar -> VName -> SolveM () -setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} - -setInfo :: TyVar -> TyVarSol -> SolveM () -setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} + setInfo v (TyVarSol v_lvl t) --- Precondition: 'v' is currently flexible and 't' has no solution. -linkTyVar :: Reason -> VName -> VName -> SolveM () -linkTyVar reason v t = do +-- Precondition: 'v' and 't' are both currently flexible. +unionTyVars :: Reason -> VName -> VName -> SolveM () +unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars (lvl, t') <- lookupTyVar t - let tp = Scalar $ TypeVar NoUniqueness (qualName t) [] - occursCheck reason v tp case (v_info, t') of -- When either is completely unconstrained. @@ -450,12 +455,12 @@ linkTyVar reason v t = do (TyVarParam {}, _) -> isParam (_, Right t'') -> - error $ "linkTyVar: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' + error $ "unionTyVars: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' -- Finally insert the actual link. setLink v t where - unknown = error $ "linkTyVar: Nothing v: " <> prettyNameString v + unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v alreadySolved = error $ "Type variable already solved: " <> prettyNameString v isParam = error $ "Type name is a type parameter: " <> prettyNameString v @@ -530,8 +535,8 @@ solveEq reason orig_t1 orig_t2 = do (Just lvl, Nothing) -> subTyVar reason v1 lvl t2' (Nothing, Just lvl) -> subTyVar reason v2 lvl t1' (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> linkTyVar reason v1 v2 - | otherwise -> linkTyVar reason v2 v1 + | lvl1 <= lvl2 -> unionTyVars reason v1 v2 + | otherwise -> unionTyVars reason v2 v1 (Scalar (TypeVar _ (QualName [] v1) []), t2') | Just lvl <- flexible v1 -> subTyVar reason v1 lvl t2' From 5522eb85c0aa1c42e1b38a10d793ee7004a655bf Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:13:30 +0200 Subject: [PATCH 223/296] Check for equality. --- src/Language/Futhark/TypeChecker/Constraints.hs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index af5a5f2cfd..b9a59e90bb 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -575,6 +575,17 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) Right _ -> pure () +solveTyVar (tv, (_, TyVarEql loc)) = do + (_, tv_t) <- lookupTyVar tv + case tv_t of + Left _ -> pure () + Right ty + | orderZero ty -> pure () + | otherwise -> + throwError . TypeError loc mempty $ + "Type" + indent 2 (align (pretty ty)) + "does not support equality (may contain function)." solveTyVar (_, _) = pure () From 43744e5e4f6c0e47bf4faa2135e1149cb5454ff1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:19:21 +0200 Subject: [PATCH 224/296] Detect ambiguous equality type. --- src/Language/Futhark/TypeChecker/Constraints.hs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b9a59e90bb..be19eccec9 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -578,7 +578,10 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do solveTyVar (tv, (_, TyVarEql loc)) = do (_, tv_t) <- lookupTyVar tv case tv_t of - Left _ -> pure () + Left _ -> + throwError . TypeError loc mempty $ + "Type is ambiguous (must be equality type)" + "Add a type annotation to disambiguate the type." Right ty | orderZero ty -> pure () | otherwise -> From 7329e8bbf7400b0ba7dbe5f4301d4fc1e2814ee3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:22:37 +0200 Subject: [PATCH 225/296] Abstraction. --- .../Futhark/TypeChecker/Constraints.hs | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index be19eccec9..c96fa9f442 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -26,7 +26,7 @@ import Data.Maybe import Data.Set qualified as S import Futhark.Util.Pretty import Language.Futhark -import Language.Futhark.TypeChecker.Monad (TypeError (..)) +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..)) import Language.Futhark.TypeChecker.Types (substTyVars) -- | The reason for a type constraint. Used to generate type error @@ -201,11 +201,15 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) +typeError :: Loc -> Notes -> Doc () -> SolveM () +typeError loc notes msg = + throwError $ TypeError loc notes msg + occursCheck :: Reason -> VName -> Type -> SolveM () occursCheck reason v tp = do vars <- gets solverTyVars let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . throwError . TypeError (locOf reason) mempty $ + when (v `S.member` typeVars tp') . typeError (locOf reason) mempty $ "Occurs check: cannot instantiate" <+> prettyName v <+> "with" @@ -220,9 +224,9 @@ unifySharedConstructors :: unifySharedConstructors reason cs1 cs2 = forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> if length ts1 == length ts2 - then zipWithM (solveEq reason) ts1 ts2 + then zipWithM_ (solveEq reason) ts1 ts2 else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructor" indent 2 (pretty (Sum (M.singleton c ts1))) "with type of constructor" @@ -241,9 +245,9 @@ unifySharedFields reason fs1 fs2 = mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () -scopeViolation :: Reason -> VName -> Type -> VName -> SolveM a +scopeViolation :: Reason -> VName -> Type -> VName -> SolveM () scopeViolation reason v1 ty v2 = - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type" indent 2 (pretty ty) "with" @@ -282,7 +286,7 @@ subTyVar reason v v_lvl t = do if t `elem` map (Scalar . Prim) v_pts then pure () else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with" @@ -293,7 +297,7 @@ subTyVar reason v v_lvl t = do if all (`elem` M.keys cs2) (M.keys cs1) then unifySharedConstructors reason cs1 cs2 else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructors" indent 2 (pretty (Sum cs1)) "with type" @@ -301,7 +305,7 @@ subTyVar reason v v_lvl t = do ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), _ ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructors" indent 2 (pretty (Sum cs1)) "with type" @@ -312,7 +316,7 @@ subTyVar reason v v_lvl t = do if all (`elem` M.keys fs2) (M.keys fs1) then unifySharedFields reason fs1 fs2 else - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify record type with fields" indent 2 (pretty (Record fs1)) "with record type" @@ -320,7 +324,7 @@ subTyVar reason v v_lvl t = do ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), _ ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify record type with fields" indent 2 (pretty (Record fs1)) "with type" @@ -366,7 +370,7 @@ unionTyVars reason v t = do let pts = L.intersect v_pts t_pts in if null pts then - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be one of" @@ -375,14 +379,14 @@ unionTyVars reason v t = do ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarRecord {} ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarSum {} ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be sum." @@ -397,13 +401,13 @@ unionTyVars reason v t = do ( TyVarUnsol _ TyVarSum {}, Left (TyVarPrim _ pts) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarSum _ cs1), Left (TyVarRecord _ fs) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify type with constructors" indent 2 (pretty (Sum cs1)) "with type" @@ -423,13 +427,13 @@ unionTyVars reason v t = do ( TyVarUnsol _ TyVarRecord {}, Left (TyVarPrim _ pts) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarRecord _ fs1), Left (TyVarSum _ cs) ) -> - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify record type" indent 2 (pretty (Record fs1)) "with type" @@ -504,7 +508,7 @@ solveEq reason orig_t1 orig_t2 = do where cannotUnify = do tyvars <- gets solverTyVars - throwError . TypeError (locOf reason) mempty $ + typeError (locOf reason) mempty $ "Cannot unify" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) "with" @@ -558,7 +562,7 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do (_, tv_t) <- lookupTyVar tv case tv_t of Left _ -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type" <+> prettyName tv <+> "is ambiguous." @@ -570,7 +574,7 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do (_, tv_t) <- lookupTyVar tv case tv_t of Left _ -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type is ambiguous." "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) @@ -579,13 +583,13 @@ solveTyVar (tv, (_, TyVarEql loc)) = do (_, tv_t) <- lookupTyVar tv case tv_t of Left _ -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type is ambiguous (must be equality type)" "Add a type annotation to disambiguate the type." Right ty | orderZero ty -> pure () | otherwise -> - throwError . TypeError loc mempty $ + typeError loc mempty $ "Type" indent 2 (align (pretty ty)) "does not support equality (may contain function)." From 8a99272d1c1e4d1ca76062ee1e0accf1f6d40f32 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 18:24:39 +0200 Subject: [PATCH 226/296] Refine. --- src/Language/Futhark/TypeChecker/Constraints.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index c96fa9f442..b85799cc1f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -582,10 +582,11 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do solveTyVar (tv, (_, TyVarEql loc)) = do (_, tv_t) <- lookupTyVar tv case tv_t of - Left _ -> + Left TyVarEql {} -> typeError loc mempty $ "Type is ambiguous (must be equality type)" "Add a type annotation to disambiguate the type." + Left _ -> pure () Right ty | orderZero ty -> pure () | otherwise -> From f7892a191d36a459d616de568ef280a8a7a96e8e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 20:35:01 +0200 Subject: [PATCH 227/296] More fixes. --- .../Futhark/TypeChecker/Constraints.hs | 78 +++++++++------- src/Language/Futhark/TypeChecker/Rank.hs | 7 +- src/Language/Futhark/TypeChecker/Terms2.hs | 89 +++++++++++-------- 3 files changed, 102 insertions(+), 72 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b85799cc1f..b7bd2f5d81 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -11,6 +11,7 @@ module Language.Futhark.TypeChecker.Constraints TyVars, TyParams, Solution, + UnconTyVar, solve, ) where @@ -89,7 +90,7 @@ type Constraints = [Ct] -- that it is the type of. data TyVarInfo = -- | Can be substituted with anything. - TyVarFree Loc + TyVarFree Loc Liftedness | -- | Can only be substituted with these primitive types. TyVarPrim Loc [PrimType] | -- | Must be a record with these fields. @@ -101,14 +102,14 @@ data TyVarInfo deriving (Show, Eq) instance Pretty TyVarInfo where - pretty (TyVarFree _) = "free" + pretty (TyVarFree _ l) = "free" <+> pretty l pretty (TyVarPrim _ pts) = "∈" <+> pretty pts pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs pretty (TyVarEql _) = "equality" instance Located TyVarInfo where - locOf (TyVarFree loc) = loc + locOf (TyVarFree loc _) = loc locOf (TyVarPrim loc _) = loc locOf (TyVarRecord loc _) = loc locOf (TyVarSum loc _) = loc @@ -158,7 +159,7 @@ substTyVar m v = Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing -lookupTyVar :: TyVar -> SolveM (Int, Either TyVarInfo Type) +lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) lookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of @@ -181,7 +182,11 @@ setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solv -- sides that contain a type variable. type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -solution :: SolverState -> ([VName], Solution) +-- | An unconstrained type variable comprises a name and (ironically) +-- a constraint on how it can be instantiated. +type UnconTyVar = (VName, Liftedness) + +solution :: SolverState -> ([UnconTyVar], Solution) solution s = ( mapMaybe unconstrained $ M.toList $ solverTyVars s, M.mapMaybe mkSubst $ solverTyVars s @@ -195,7 +200,7 @@ solution s = mkSubst (Right (TyVarUnsol _ (TyVarPrim _ pts))) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, Right (TyVarUnsol _ (TyVarFree _))) = Just v + unconstrained (v, Right (TyVarUnsol _ (TyVarFree _ l))) = Just (v, l) unconstrained _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} @@ -257,26 +262,11 @@ scopeViolation reason v1 ty v2 = <+> dquotes (prettyName v2) <+> "is rigidly bound in a deeper scope." -scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () -scopeCheck reason v v_lvl ty = do - mapM_ check $ typeVars ty - where - check ty_v = do - ty_v_info <- gets $ M.lookup ty_v . solverTyVars - case ty_v_info of - Just (Right (TyVarParam ty_v_lvl _)) - | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v - Just (Right (TyVarUnsol ty_v_lvl info)) - | ty_v_lvl /= v_lvl -> - setInfo ty_v $ TyVarUnsol v_lvl info - _ -> pure () - -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v v_lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars - scopeCheck reason v v_lvl t case (v_info, t) of (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () @@ -348,22 +338,27 @@ subTyVar reason v v_lvl t = do unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - (lvl, t') <- lookupTyVar t + (t_lvl, t') <- lookupTyVar t case (v_info, t') of + ( TyVarUnsol _ (TyVarFree _ v_l), + Left (TyVarFree t_loc t_l) + ) + | v_l /= t_l -> + setInfo t $ TyVarUnsol t_lvl $ TyVarFree t_loc (min v_l t_l) -- When either is completely unconstrained. (TyVarUnsol _ TyVarFree {}, _) -> pure () ( TyVarUnsol _ info, Left (TyVarFree {}) ) -> - setInfo t (TyVarUnsol lvl info) + setInfo t (TyVarUnsol t_lvl info) -- -- TyVarPrim cases ( TyVarUnsol _ info@TyVarPrim {}, Left TyVarEql {} ) -> - setInfo t (TyVarUnsol lvl info) + setInfo t (TyVarUnsol t_lvl info) ( TyVarUnsol _ (TyVarPrim _ v_pts), Left (TyVarPrim t_loc t_pts) ) -> @@ -375,7 +370,7 @@ unionTyVars reason v t = do indent 2 (pretty v_pts) "with type that must be one of" indent 2 (pretty t_pts) - else setInfo t (TyVarUnsol lvl (TyVarPrim t_loc pts)) + else setInfo t (TyVarUnsol t_lvl (TyVarPrim t_loc pts)) ( TyVarUnsol _ (TyVarPrim _ v_pts), Left TyVarRecord {} ) -> @@ -397,7 +392,7 @@ unionTyVars reason v t = do ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 - setInfo t (TyVarUnsol lvl (TyVarSum loc cs3)) + setInfo t (TyVarUnsol t_lvl (TyVarSum loc cs3)) ( TyVarUnsol _ TyVarSum {}, Left (TyVarPrim _ pts) ) -> @@ -423,7 +418,7 @@ unionTyVars reason v t = do ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 - setInfo t (TyVarUnsol lvl (TyVarRecord loc fs3)) + setInfo t (TyVarUnsol t_lvl (TyVarRecord loc fs3)) ( TyVarUnsol _ TyVarRecord {}, Left (TyVarPrim _ pts) ) -> @@ -557,7 +552,18 @@ solveCt ct = CtEq reason t1 t2 -> solveEq reason t1 t2 CtAM {} -> pure () -- Good vibes only. -solveTyVar :: (VName, (Int, TyVarInfo)) -> SolveM () +scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () +scopeCheck reason v v_lvl ty = do + mapM_ check $ typeVars ty + where + check ty_v = do + ty_v_info <- gets $ M.lookup ty_v . solverTyVars + case ty_v_info of + Just (Right (TyVarParam ty_v_lvl _)) + | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + _ -> pure () + +solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do (_, tv_t) <- lookupTyVar tv case tv_t of @@ -594,10 +600,18 @@ solveTyVar (tv, (_, TyVarEql loc)) = do "Type" indent 2 (align (pretty ty)) "does not support equality (may contain function)." -solveTyVar (_, _) = - pure () - -solve :: Constraints -> TyParams -> TyVars -> Either TypeError ([VName], Solution) +solveTyVar (tv, (lvl, _)) = do + (_, tv_t) <- lookupTyVar tv + case tv_t of + Right ty -> + scopeCheck (Reason mempty) tv lvl ty + _ -> pure () + +solve :: + Constraints -> + TyParams -> + TyVars -> + Either TypeError ([UnconTyVar], Solution) solve constraints typarams tyvars = second solution . runExcept diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index a052d44f54..c00ba3b106 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -160,7 +160,7 @@ addCt (CtAM _ r m f) = do addObj tr addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () -addTyVarInfo _ (_, TyVarFree _) = pure () +addTyVarInfo _ (_, TyVarFree {}) = pure () addTyVarInfo tv (_, TyVarPrim {}) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarRecord {}) = @@ -392,8 +392,11 @@ addRankInfo t = do t' <- newTyVar t old_tyvars <- asks envTyVars let (level, tvinfo) = fromJust $ old_tyvars M.!? t + l = case tvinfo of + TyVarFree _ tvinfo_l -> tvinfo_l + _ -> Unlifted modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo) $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree $ locOf tvinfo) $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree (locOf tvinfo) l) $ substTyVars s} class SubstRanks a where substRanks :: (MonadTypeChecker m) => a -> SubstT m a diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cb3082908a..c362e7ea5c 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -231,11 +231,15 @@ newTyVarWith desc info = do modify $ \s -> s {termTyVars = M.insert v (lvl, info) $ termTyVars s} pure v -newTyVar :: (Located loc) => loc -> Name -> TermM TyVar -newTyVar loc desc = newTyVarWith desc $ TyVarFree $ locOf loc +newTyVar :: (Located loc) => loc -> Liftedness -> Name -> TermM TyVar +newTyVar loc l desc = newTyVarWith desc $ TyVarFree (locOf loc) l -newType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) -newType loc desc u = tyVarType u <$> newTyVar loc desc +newType :: (Located loc) => loc -> Liftedness -> Name -> u -> TermM (TypeBase dim u) +newType loc l desc u = tyVarType u <$> newTyVar loc l desc + +-- | New type that must be allowed as an array element. +newElemType :: (Located loc) => loc -> Name -> u -> TermM (TypeBase dim u) +newElemType loc desc u = tyVarType u <$> newTyVar loc Unlifted desc newTypeWithField :: SrcLoc -> Name -> Name -> Type -> TermM Type newTypeWithField loc desc k t = @@ -404,8 +408,8 @@ instTypeScheme _qn loc tparams t = do (names, substs) <- fmap (unzip . catMaybes) $ forM tparams $ \tparam -> case tparam of - TypeParamType _ v _ -> do - v' <- newTyVar loc $ nameFromString $ takeWhile isAscii $ baseString v + TypeParamType l v _ -> do + v' <- newTyVar loc l $ nameFromString $ takeWhile isAscii $ baseString v pure $ Just (v, (typeParamName tparam, tyVarType NoUniqueness v')) TypeParamDim {} -> pure Nothing @@ -490,13 +494,13 @@ checkPat' (Id name NoInfo loc) (Ascribed t) = do t' <- asStructType t pure $ Id name (Info t') loc checkPat' (Id name NoInfo loc) NoneInferred = do - t <- newType loc "t" Observe + t <- newType loc Lifted "t" Observe pure $ Id name (Info t) loc checkPat' (Wildcard _ loc) (Ascribed t) = do t' <- asStructType t pure $ Wildcard (Info t') loc checkPat' (Wildcard NoInfo loc) NoneInferred = do - t <- newType loc "t" Observe + t <- newType loc Lifted "t" Observe pure $ Wildcard (Info t) loc checkPat' (TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, @@ -505,7 +509,7 @@ checkPat' (TuplePat ps loc) (Ascribed t) <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc | otherwise = do - ps_tvs <- replicateM (length ps) (newTyVar loc "t") + ps_tvs <- replicateM (length ps) (newTyVar loc Lifted "t") ctEq (Reason (locOf loc)) (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) t TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = @@ -525,7 +529,9 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) L.sort (map fst p_fs) == L.sort (M.keys t_fs) = RecordPat . M.toList <$> check t_fs <*> pure loc | otherwise = do - p_fs' <- traverse (const $ newType loc "t" NoUniqueness) $ M.fromList p_fs + p_fs' <- + traverse (const $ newType loc Lifted "t" NoUniqueness) $ + M.fromList p_fs ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t checkPat' p $ Ascribed $ Observe <$ Scalar (Record p_fs') where @@ -578,7 +584,7 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do - p_t <- newType (srclocOf p) "t" Observe + p_t <- newType (srclocOf p) Lifted "t" Observe checkPat' p $ Ascribed p_t t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' ctEq (Reason (locOf loc)) t' t @@ -730,8 +736,8 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do (a, b) <- split $ Scalar t pure (arrayOf s a, arrayOf s b) split ftype' = do - a <- newType loc "arg" NoUniqueness - b <- newType loc "res" Nonunique + a <- newType loc Lifted "arg" NoUniqueness + b <- newType loc Lifted "res" Nonunique ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) @@ -759,7 +765,7 @@ mustHaveFields loc t [f] ve_t = do rt :: Type <- newTypeWithField loc "ft" f ve_t ctEq (Reason (locOf loc)) t rt mustHaveFields loc t (f : fs) ve_t = do - ft <- newType loc "ft" NoUniqueness + ft <- newType loc Lifted "ft" NoUniqueness rt <- newTypeWithField loc "rt" f ft mustHaveFields loc ft fs ve_t ctEq (Reason (locOf loc)) t rt @@ -844,7 +850,7 @@ checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc checkExp (Hole NoInfo loc) = - Hole <$> (Info <$> newType loc "hole" NoUniqueness) <*> pure loc + Hole <$> (Info <$> newType loc Lifted "hole" NoUniqueness) <*> pure loc checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (TupLit es loc) = @@ -878,7 +884,7 @@ checkExp (ArrayLit es _ loc) = do -- type variables for pathologically large arrays with -- type-unsuffixed integers. Add some special case that handles that -- more efficiently. - et <- newType loc "et" NoUniqueness + et <- newElemType loc "et" NoUniqueness es' <- forM es $ \e -> do e' <- checkExp e e_t <- expType e' @@ -974,7 +980,7 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do optype <- lookupVar loc op e' <- checkExp e e_t <- expType e' - t2 <- newType loc "t" NoUniqueness + t2 <- newType loc Lifted "t" NoUniqueness t2' <- asStructType t2 let f1 = frameOf e' (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((f1, e_t) NE.:| [(mempty, t2)]) @@ -997,7 +1003,7 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do optype <- lookupVar loc op e' <- checkExp e e_t <- expType e' - t1 <- newType loc "t" NoUniqueness + t1 <- newType loc Lifted "t" NoUniqueness t1' <- asStructType t1 let f2 = frameOf e' (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((mempty, t1) NE.:| [(f2, e_t)]) @@ -1019,8 +1025,8 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do loc -- checkExp (ProjectSection fields NoInfo loc) = do - a <- newType loc "a" NoUniqueness - b <- newType loc "b" NoUniqueness + a <- newType loc Lifted "a" NoUniqueness + b <- newType loc Lifted "b" NoUniqueness mustHaveFields loc a fields b ft <- asStructType $ Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ b `setUniqueness` Nonunique pure $ ProjectSection fields (Info ft) loc @@ -1087,7 +1093,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do pure e' maybe_step' <- traverse check maybe_step end' <- traverse check end - range_t <- newType loc "range" NoUniqueness + range_t <- newElemType loc "range" NoUniqueness range_t' <- asType range_t start_t <- expType start' ctEq (Reason (locOf start')) range_t' (arrayOfRank 1 start_t) @@ -1095,7 +1101,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do -- checkExp (Project k e NoInfo loc) = do e' <- checkExp e - kt <- newType loc "kt" NoUniqueness + kt <- newType loc Lifted "kt" NoUniqueness t <- newTypeWithField loc "t" k kt e_t <- expType e' ctEq (Reason (locOf e')) e_t t @@ -1113,9 +1119,9 @@ checkExp (RecordUpdate src fields ve NoInfo loc) = do -- checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice - index_arg_t <- newType loc "index" NoUniqueness - index_elem_t <- newType loc "index_elem" NoUniqueness - index_res_t <- newType loc "index_res" NoUniqueness + index_arg_t <- newElemType loc "index" NoUniqueness + index_elem_t <- newElemType loc "index_elem" NoUniqueness + index_res_t <- newElemType loc "index_res" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (Reason (locOf loc)) index_arg_t $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf loc)) index_res_t $ arrayOfRank (length slice) index_elem_t @@ -1126,8 +1132,8 @@ checkExp (AppExp (Index e slice loc) _) = do e' <- checkExp e e_t <- expType e' slice' <- checkSlice slice - index_tv <- newTyVar loc "index" - index_elem_t <- newType loc "index_elem" NoUniqueness + index_tv <- newTyVar loc Unlifted "index" + index_elem_t <- newElemType loc "index_elem" NoUniqueness let num_slices = length $ filter isSlice slice ctEq (Reason (locOf loc)) (tyVarType NoUniqueness index_tv) $ arrayOfRank num_slices index_elem_t ctEq (Reason (locOf e')) e_t $ arrayOfRank (length slice) index_elem_t @@ -1140,7 +1146,7 @@ checkExp (Update src slice ve loc) = do ve' <- checkExp ve ve_t <- expType ve' let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" NoUniqueness + update_elem_t <- newElemType loc "update_elem" NoUniqueness ctEq (Reason (locOf src')) src_t $ arrayOfRank (length slice) update_elem_t ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t pure $ Update src' slice' ve' loc @@ -1154,7 +1160,7 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = do ve' <- checkExp ve ve_t <- expType ve' let num_slices = length $ filter isSlice slice - update_elem_t <- newType loc "update_elem" NoUniqueness + update_elem_t <- newElemType loc "update_elem" NoUniqueness ctEq (Reason (locOf loc)) src_t $ arrayOfRank (length slice) update_elem_t ctEq (Reason (locOf ve')) ve_t $ arrayOfRank num_slices update_elem_t bind [dest'] $ do @@ -1170,12 +1176,14 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do e2_t <- expType e2' e3' <- checkExp e3 e3_t <- expType e3' + if_t <- newType loc SizeLifted "if_t" NoUniqueness ctEq (Reason (locOf e1')) e1_t (Scalar (Prim Bool)) - ctEq (Reason (locOf loc)) e2_t e3_t + ctEq (Reason (locOf loc)) e2_t if_t + ctEq (Reason (locOf loc)) e3_t if_t - e2_t' <- asStructType e2_t - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes e2_t' []) + if_t' <- asStructType if_t + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes if_t' []) -- checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e @@ -1205,7 +1213,7 @@ checkExp (AppExp (Loop _ pat arg form body loc) _) = do pure (While cond', body') ForIn elemp arr -> do arr' <- checkExp arr - elem_t <- newType elemp "elem" NoUniqueness + elem_t <- newElemType elemp "elem" NoUniqueness arr_t <- expType arr' elem_t' <- asType elem_t ctEq (Reason (locOf arr')) arr_t $ arrayOfRank 1 elem_t' @@ -1270,13 +1278,16 @@ doDefaults tyvars_at_toplevel substs = do pure $ M.map (substTyVars (`M.lookup` substs')) substs' generalise :: - TypeBase () NoUniqueness -> [VName] -> Solution -> ([TypeParam], [VName]) + TypeBase () NoUniqueness -> + [UnconTyVar] -> + Solution -> + ([TypeParam], [VName]) generalise fun_t unconstrained solution = -- Candidates for let-generalisation are those type variables that -- are used in fun_t. let visible = foldMap expandTyVars $ typeVars fun_t - onTyVar v - | v `S.member` visible = Left $ TypeParamType Unlifted v mempty + onTyVar (v, l) + | v `S.member` visible = Left $ TypeParamType l v mempty | otherwise = Right v in partitionEithers $ map onTyVar unconstrained where @@ -1286,7 +1297,7 @@ generalise fun_t unconstrained solution = _ -> S.singleton v generaliseAndDefaults :: - [VName] -> + [UnconTyVar] -> Solution -> TypeBase () NoUniqueness -> TermM ([TypeParam], M.Map VName (TypeBase () NoUniqueness)) @@ -1354,6 +1365,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines [ "## constraints:", unlines $ map prettyString cts', + "## typarams:", + unlines (map (prettyString . bimap prettyNameString fst) (M.toList typarams)), "## tyvars':", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", @@ -1396,7 +1409,7 @@ checkSingleExp e = runTermM $ do -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> - TypeM (Either TypeError ([VName], M.Map TyVar (TypeBase () NoUniqueness)), Exp) + TypeM (Either TypeError ([UnconTyVar], M.Map TyVar (TypeBase () NoUniqueness)), Exp) checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints From fbafb6db1cc4c53033ad1555bbcc354a3e50a2d5 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 20:59:10 +0200 Subject: [PATCH 228/296] Define before use. --- src/Language/Futhark/TypeChecker/Terms2.hs | 34 +++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index c362e7ea5c..d231cf7263 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -673,23 +673,6 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApply :: - SrcLoc -> - Maybe (QualName VName) -> - (Shape Size, Type) -> - NE.NonEmpty (Shape Size, Type) -> - TermM (Type, NE.NonEmpty AutoMap) -checkApply loc fname (fframe, ftype) args = do - ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args - pure (rt, argts) - where - onArg (i, f_f, f_t) (argframe, argtype) = do - (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) - pure - ( (i + 1, autoFrame am, rt), - am - ) - checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do (a, b) <- split ftype @@ -741,6 +724,23 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) +checkApply :: + SrcLoc -> + Maybe (QualName VName) -> + (Shape Size, Type) -> + NE.NonEmpty (Shape Size, Type) -> + TermM (Type, NE.NonEmpty AutoMap) +checkApply loc fname (fframe, ftype) args = do + ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args + pure (rt, argts) + where + onArg (i, f_f, f_t) (argframe, argtype) = do + (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) + pure + ( (i + 1, autoFrame am, rt), + am + ) + checkSlice :: SliceBase NoInfo VName -> TermM [DimIndex] checkSlice = mapM checkDimIndex where From 4b2a9e2ec47124fb48f718167377dd6b52f4ecaa Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:00:31 +0200 Subject: [PATCH 229/296] Break long line. --- src/Language/Futhark/TypeChecker/Terms2.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index d231cf7263..f35aa83de9 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -707,7 +707,11 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do ] pure ( arrayOf (toShape (SVar m)) b, - AutoMap {autoRep = toShape r_var, autoMap = toShape m_var, autoFrame = toShape m_var <> fframe} + AutoMap + { autoRep = toShape r_var, + autoMap = toShape m_var, + autoFrame = toShape m_var <> fframe + } ) where toSComp (Var (QualName [] x) _ _) = SVar x From d393544492b332d7c187ae863127eeb963c93572 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:08:16 +0200 Subject: [PATCH 230/296] Add missing cases. --- src/Language/Futhark/TypeChecker/Rank.hs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c00ba3b106..6f597ca8f0 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -452,6 +452,22 @@ updAM rank_map e = in AppExp (Apply f' args' loc) res AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res + OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc -> + OpSectionRight + name + t + (updAM rank_map arg) + (Info (pa, t1a), Info (pb, t1b, argext, upd am)) + t2 + loc + OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc -> + OpSectionLeft + name + t + (updAM rank_map arg) + (Info (pa, t1a, argext, upd am), Info (pb, t1b)) + (ret, retext) + loc _ -> runIdentity $ astMap mapper e where dimToRank (Var (QualName [] x) _ _) = From d86eec2eb49e192541cfacd00fb1bc3237bc624f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:22:07 +0200 Subject: [PATCH 231/296] This is OK now. --- tests/record-update6.fut | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/record-update6.fut b/tests/record-update6.fut index 53349ac0ab..fef501e1a2 100644 --- a/tests/record-update6.fut +++ b/tests/record-update6.fut @@ -1,10 +1,9 @@ -- Inference of record in lambda. -- == --- error: Full type of type octnode = {body: i32} -def f (octree: []octnode) (i: i32) = +entry f (octree: []octnode) (i: i32) = map (\n -> if n.body != i then n else n with body = 0) octree From 43c85b30d600178d79b563c7a977b8546e12f24c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jul 2024 21:54:35 +0200 Subject: [PATCH 232/296] Correct handling of type annotation. --- src/Language/Futhark/TypeChecker/Terms.hs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index b1dfbf494e..d2f03ecc9c 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -687,13 +687,13 @@ checkExp (Lambda params body rettype_te (Info (RetType _ rt)) loc) = do params'' <- mapM updateTypes params' - (rettype', rettype_st) <- - case rettype_checked of - Just (te, _, ext) -> - pure (Just te, RetType ext rt') - Nothing -> do - RetType ext ret <- inferReturnSizes params'' $ toRes Nonunique body_t - pure (Nothing, RetType ext ret) + (rettype', rettype_st) <- case rettype_checked of + Just (te, ret, ext) -> do + ret' <- normTypeFully ret + pure (Just te, RetType ext ret') + Nothing -> do + ret <- inferReturnSizes params'' $ toRes Nonunique body_t + pure (Nothing, ret) pure (params'', body', rettype', rettype_st) From 44c54056cfad05ab501f239e2b707f661047d4a8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jul 2024 08:55:28 +0200 Subject: [PATCH 233/296] Simplify. --- .../Futhark/TypeChecker/Constraints.hs | 57 +++++++++++-------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b7bd2f5d81..178c22e2b9 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -171,6 +171,14 @@ lookupTyVar orig = do Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) f orig +-- | Variable must be flexible. +lookupTyVarInfo :: TyVar -> SolveM (Level, TyVarInfo) +lookupTyVarInfo v = do + (lvl, r) <- lookupTyVar v + case r of + Left info -> pure (lvl, info) + Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v + setLink :: TyVar -> VName -> SolveM () setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} @@ -338,11 +346,15 @@ subTyVar reason v v_lvl t = do unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - (t_lvl, t') <- lookupTyVar t + (t_lvl, t_info) <- lookupTyVarInfo t + + -- Insert the link from v to t, and then update the info of t based + -- on the existing info of v and t. + setLink v t - case (v_info, t') of + case (v_info, t_info) of ( TyVarUnsol _ (TyVarFree _ v_l), - Left (TyVarFree t_loc t_l) + TyVarFree t_loc t_l ) | v_l /= t_l -> setInfo t $ TyVarUnsol t_lvl $ TyVarFree t_loc (min v_l t_l) @@ -350,17 +362,17 @@ unionTyVars reason v t = do (TyVarUnsol _ TyVarFree {}, _) -> pure () ( TyVarUnsol _ info, - Left (TyVarFree {}) + TyVarFree {} ) -> setInfo t (TyVarUnsol t_lvl info) -- -- TyVarPrim cases ( TyVarUnsol _ info@TyVarPrim {}, - Left TyVarEql {} + TyVarEql {} ) -> setInfo t (TyVarUnsol t_lvl info) ( TyVarUnsol _ (TyVarPrim _ v_pts), - Left (TyVarPrim t_loc t_pts) + TyVarPrim t_loc t_pts ) -> let pts = L.intersect v_pts t_pts in if null pts @@ -372,14 +384,14 @@ unionTyVars reason v t = do indent 2 (pretty t_pts) else setInfo t (TyVarUnsol t_lvl (TyVarPrim t_loc pts)) ( TyVarUnsol _ (TyVarPrim _ v_pts), - Left TyVarRecord {} + TyVarRecord {} ) -> typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." ( TyVarUnsol _ (TyVarPrim _ v_pts), - Left TyVarSum {} + TyVarSum {} ) -> typeError (locOf reason) mempty $ "Cannot unify type that must be one of" @@ -388,19 +400,19 @@ unionTyVars reason v t = do -- -- TyVarSum cases ( TyVarUnsol _ (TyVarSum _ cs1), - Left (TyVarSum loc cs2) + TyVarSum loc cs2 ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 setInfo t (TyVarUnsol t_lvl (TyVarSum loc cs3)) ( TyVarUnsol _ TyVarSum {}, - Left (TyVarPrim _ pts) + TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarSum _ cs1), - Left (TyVarRecord _ fs) + TyVarRecord _ fs ) -> typeError (locOf reason) mempty $ "Cannot unify type with constructors" @@ -408,25 +420,25 @@ unionTyVars reason v t = do "with type" indent 2 (pretty (Scalar (Record fs))) ( TyVarUnsol _ (TyVarSum _ cs1), - Left (TyVarEql _) + TyVarEql _ ) -> mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases ( TyVarUnsol _ (TyVarRecord _ fs1), - Left (TyVarRecord loc fs2) + TyVarRecord loc fs2 ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 setInfo t (TyVarUnsol t_lvl (TyVarRecord loc fs3)) ( TyVarUnsol _ TyVarRecord {}, - Left (TyVarPrim _ pts) + TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) ( TyVarUnsol _ (TyVarRecord _ fs1), - Left (TyVarSum _ cs) + TyVarSum _ cs ) -> typeError (locOf reason) mempty $ "Cannot unify record type" @@ -434,18 +446,18 @@ unionTyVars reason v t = do "with type" indent 2 (pretty (Scalar (Sum cs))) ( TyVarUnsol _ (TyVarRecord _ fs1), - Left (TyVarEql _) + TyVarEql _ ) -> mapM_ (mustSupportEql reason) fs1 -- -- TyVarEql cases - (TyVarUnsol _ (TyVarEql _), Left TyVarPrim {}) -> + (TyVarUnsol _ (TyVarEql _), TyVarPrim {}) -> pure () - (TyVarUnsol _ (TyVarEql _), Left TyVarEql {}) -> + (TyVarUnsol _ (TyVarEql _), TyVarEql {}) -> pure () - (TyVarUnsol _ (TyVarEql _), Left (TyVarRecord _ fs)) -> + (TyVarUnsol _ (TyVarEql _), TyVarRecord _ fs) -> mustSupportEql reason $ Scalar $ Record fs - (TyVarUnsol _ (TyVarEql _), Left (TyVarSum _ cs)) -> + (TyVarUnsol _ (TyVarEql _), TyVarSum _ cs) -> mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases @@ -453,11 +465,6 @@ unionTyVars reason v t = do alreadySolved (TyVarParam {}, _) -> isParam - (_, Right t'') -> - error $ "unionTyVars: rhs " <> prettyNameString t <> " is solved as " <> prettyString t'' - - -- Finally insert the actual link. - setLink v t where unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v From 2333b7e1a5e13fadc3ee5e03573996c08ef009d1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 14 Jul 2024 22:38:31 +0200 Subject: [PATCH 234/296] Also AUTOMAP in return type annotations. --- src/Language/Futhark/TypeChecker/Rank.hs | 62 ++++++++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 13 ++--- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 6f597ca8f0..24254d7392 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -256,13 +256,19 @@ rankAnalysis1 :: M.Map TyVar Type -> [Pat ParamType] -> Exp -> - m (([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp) -rankAnalysis1 loc cs tyVars artificial params body = do - solutions <- rankAnalysis loc cs tyVars artificial params body + Maybe (TypeExp Exp VName) -> + m + ( ([Ct], M.Map TyVar Type, TyVars), + [Pat ParamType], + Exp, + Maybe (TypeExp Exp VName) + ) +rankAnalysis1 loc cs tyVars artificial params body retdecl = do + solutions <- rankAnalysis loc cs tyVars artificial params body retdecl case solutions of [sol] -> pure sol sols -> do - let (_, _, bodies') = unzip3 sols + let (_, _, bodies', _) = L.unzip4 sols typeError loc mempty $ stack $ [ "Rank ILP is ambiguous.", @@ -278,10 +284,17 @@ rankAnalysis :: M.Map TyVar Type -> [Pat ParamType] -> Exp -> - m [(([Ct], M.Map TyVar Type, TyVars), [Pat ParamType], Exp)] -rankAnalysis _ [] tyVars artificial params body = - pure [(([], artificial, tyVars), params, body)] -rankAnalysis loc cs tyVars artificial params body = do + Maybe (TypeExp Exp VName) -> + m + [ ( ([Ct], M.Map TyVar Type, TyVars), + [Pat ParamType], + Exp, + Maybe (TypeExp Exp VName) + ) + ] +rankAnalysis _ [] tyVars artificial params body retdecl = + pure [(([], artificial, tyVars), params, body, retdecl)] +rankAnalysis loc cs tyVars artificial params body retdecl = do debugTraceM 3 $ unlines [ "##rankAnalysis", @@ -294,18 +307,21 @@ rankAnalysis loc cs tyVars artificial params body = do cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps let bodys = map (`updAM` body) rank_maps params' = map ((`map` params) . updAMPat) rank_maps - pure $ zip3 cts_tyvars' params' bodys + retdecls = map ((<$> retdecl) . updAMTypeExp) rank_maps + pure $ L.zip4 cts_tyvars' params' bodys retdecls where cs' = foldMap distribAndSplitCnstrs $ foldMap distribAndSplitArrows cs +type RankMap = M.Map VName Int + substRankInfo :: (MonadTypeChecker m) => [Ct] -> M.Map VName Type -> TyVars -> - Map VName Int -> + RankMap -> m ([Ct], M.Map VName Type, TyVars) substRankInfo cs artificial tyVars rankmap = do ((cs', artificial', tyVars'), new_cs, new_tyVars) <- @@ -316,7 +332,7 @@ substRankInfo cs artificial tyVars rankmap = do isCtAM (CtAM {}) = True isCtAM _ = False -runSubstT :: (MonadTypeChecker m) => TyVars -> Map VName Int -> SubstT m a -> m (a, [Ct], TyVars) +runSubstT :: (MonadTypeChecker m) => TyVars -> RankMap -> SubstT m a -> m (a, [Ct], TyVars) runSubstT tyVars rankmap (SubstT m) = do let env = SubstEnv @@ -344,7 +360,7 @@ newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) data SubstEnv = SubstEnv { envTyVars :: TyVars, - envRanks :: Map VName Int + envRanks :: RankMap } data SubstState = SubstState @@ -443,7 +459,7 @@ instance SubstRanks TyVarInfo where instance SubstRanks (Int, TyVarInfo) where substRanks (lvl, tv) = (lvl,) <$> substRanks tv -updAM :: Map VName Int -> Exp -> Exp +updAM :: RankMap -> Exp -> Exp updAM rank_map e = case e of AppExp (Apply f args loc) res -> @@ -476,15 +492,17 @@ updAM rank_map e = shapeToRank = Shape . foldMap dimToRank upd (AutoMap r m f) = AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) - mapper = - identityMapper - { mapOnExp = pure . updAM rank_map - } + mapper = identityMapper {mapOnExp = pure . updAM rank_map} -updAMPat :: M.Map VName Int -> Pat ParamType -> Pat ParamType +updAMPat :: RankMap -> Pat ParamType -> Pat ParamType updAMPat rank_map p = runIdentity $ astMap m p where - m = - identityMapper - { mapOnExp = pure . updAM rank_map - } + m = identityMapper {mapOnExp = pure . updAM rank_map} + +updAMTypeExp :: + RankMap -> + TypeExp Exp VName -> + TypeExp Exp VName +updAMTypeExp rank_map te = runIdentity $ astMap m te + where + m = identityMapper {mapOnExp = pure . updAM rank_map} diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index f35aa83de9..cd3d460793 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1356,10 +1356,10 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do unlines $ map (\(v, t) -> prettyNameString v <> " => " <> prettyString t) (M.toList artificial) ] - onRankSolution retdecl' typarams - =<< rankAnalysis1 loc cts tyvars artificial params' body' + onRankSolution typarams + =<< rankAnalysis1 loc cts tyvars artificial params' body' retdecl' where - onRankSolution retdecl' typarams ((cts', artificial, tyvars'), params', body'') = do + onRankSolution typarams ((cts', artificial, tyvars'), params', body'', retdecl') = do solution <- bitraverse pure @@ -1400,8 +1400,8 @@ checkSingleExp e = runTermM $ do tyvars <- gets termTyVars typarams <- gets termTyParams artificial <- gets termArtificial - ((cts', _artificial', tyvars'), _, e'') <- - rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' + ((cts', _artificial', tyvars'), _, e'', _) <- + rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' Nothing case solve cts' typarams tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do @@ -1421,7 +1421,8 @@ checkSizeExp e = runTermM $ do typarams <- gets termTyParams artificial <- gets termArtificial - (cts_tyvars', _, es') <- unzip3 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' + (cts_tyvars', _, es', _) <- + L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' Nothing solutions <- forM cts_tyvars' $ \(cts', _artificial', tyvars') -> From 5f94900f1fc4dcc19a6fa614b495a77f125388f3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 14 Jul 2024 23:53:20 +0200 Subject: [PATCH 235/296] Propagate liftedness properly. --- .../Futhark/TypeChecker/Constraints.hs | 80 +++++++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 8 +- 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 178c22e2b9..ca8f713d0d 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -128,13 +128,13 @@ type Level = Int type TyVars = M.Map TyVar (Level, TyVarInfo) -- | Explicit type parameters. -type TyParams = M.Map TyVar (Level, Loc) +type TyParams = M.Map TyVar (Level, Liftedness, Loc) data TyVarSol = -- | Has been substituted with this. TyVarSol Level Type | -- | Is an explicit (rigid) type parameter in the source program. - TyVarParam Level Loc + TyVarParam Level Liftedness Loc | -- | Not substituted yet; has this constraint. TyVarUnsol Level TyVarInfo deriving (Show) @@ -148,7 +148,7 @@ initialState :: TyParams -> TyVars -> SolverState initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars where f (lvl, info) = Right $ TyVarUnsol lvl info - g (lvl, loc) = Right $ TyVarParam lvl loc + g (lvl, l, loc) = Right $ TyVarParam lvl l loc substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = @@ -159,18 +159,24 @@ substTyVar m v = Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing -lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) -lookupTyVar orig = do +maybeLookupTyVar :: TyVar -> SolveM (Maybe TyVarSol) +maybeLookupTyVar orig = do tyvars <- gets solverTyVars let f v = case M.lookup v tyvars of - Nothing -> error $ "Unknown tyvar: " <> prettyNameString v + Nothing -> pure Nothing Just (Left v') -> f v' - Just (Right (TyVarSol lvl t)) -> pure (lvl, Right t) - Just (Right (TyVarParam lvl _)) -> - pure (lvl, Right $ Scalar $ TypeVar mempty (qualName orig) []) - Just (Right (TyVarUnsol lvl info)) -> pure (lvl, Left info) + Just (Right info) -> pure $ Just info f orig +lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) +lookupTyVar orig = + maybe bad unpack <$> maybeLookupTyVar orig + where + bad = error $ "Unknown tyvar: " <> prettyNameString orig + unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig + unpack (TyVarSol lvl t) = (lvl, Right t) + unpack (TyVarUnsol lvl info) = (lvl, Left info) + -- | Variable must be flexible. lookupTyVarInfo :: TyVar -> SolveM (Level, TyVarInfo) lookupTyVarInfo v = do @@ -275,6 +281,11 @@ subTyVar :: Reason -> VName -> Int -> Type -> SolveM () subTyVar reason v v_lvl t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars + + -- Set a solution for v, then update info for t in case v has any + -- odd constraints. + setInfo v (TyVarSol v_lvl t) + case (v_info, t) of (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> pure () @@ -340,8 +351,6 @@ subTyVar reason v v_lvl t = do (Nothing, _) -> error $ "subTyVar: Nothing v: " <> prettyNameString v - setInfo v (TyVarSol v_lvl t) - -- Precondition: 'v' and 't' are both currently flexible. unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do @@ -566,10 +575,37 @@ scopeCheck reason v v_lvl ty = do check ty_v = do ty_v_info <- gets $ M.lookup ty_v . solverTyVars case ty_v_info of - Just (Right (TyVarParam ty_v_lvl _)) + Just (Right (TyVarParam ty_v_lvl _ _)) | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v _ -> pure () +-- If a type variable has a liftedness constraint, we propagate that +-- constraint to its solution. The actual checking for correct usage +-- is done later. +liftednessCheck :: Reason -> TyVar -> Liftedness -> Type -> SolveM () +liftednessCheck reason v l (Scalar (TypeVar _ (QualName [] v2) _)) = do + v2_info <- maybeLookupTyVar v2 + case v2_info of + Nothing -> + -- Is an opaque type. + pure () + Just (TyVarSol _ v2_ty) -> + liftednessCheck reason v l v2_ty + Just TyVarParam {} -> pure () + Just (TyVarUnsol lvl (TyVarFree loc v2_l)) + | l /= v2_l -> + setInfo v2 $ TyVarUnsol lvl $ TyVarFree loc (min l v2_l) + Just TyVarUnsol {} -> pure () +liftednessCheck _ _ _ (Scalar Prim {}) = pure () +liftednessCheck _ _ Lifted _ = pure () +liftednessCheck _ _ _ Array {} = pure () +liftednessCheck _ _ _ (Scalar Arrow {}) = pure () +liftednessCheck reason v l (Scalar (Record fs)) = + mapM_ (liftednessCheck reason v l) fs +liftednessCheck reason v l (Scalar (Sum cs)) = + mapM_ (mapM_ $ liftednessCheck reason v l) cs +liftednessCheck _ _ _ (Scalar TypeVar {}) = pure () + solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do (_, tv_t) <- lookupTyVar tv @@ -607,11 +643,23 @@ solveTyVar (tv, (_, TyVarEql loc)) = do "Type" indent 2 (align (pretty ty)) "does not support equality (may contain function)." -solveTyVar (tv, (lvl, _)) = do +solveTyVar (tv, (lvl, TyVarFree loc l)) = do + (_, tv_t) <- lookupTyVar tv + case tv_t of + Right ty -> do + scopeCheck (Reason loc) tv lvl ty + liftednessCheck (Reason loc) tv l ty + _ -> pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do (_, tv_t) <- lookupTyVar tv case tv_t of - Right ty -> - scopeCheck (Reason mempty) tv lvl ty + Right ty + | ty `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + typeError loc mempty $ + "Numeric constant inferred to be of type" + indent 2 (align (pretty ty)) + "which is not possible." _ -> pure () solve :: diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cd3d460793..e58154801e 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -655,9 +655,8 @@ bindTypeParams tparams m = types = mapMaybe typeParamType tparams typeParamType (TypeParamType l v _) = Just (v, TypeAbbr l [] $ RetType [] $ Scalar (TypeVar mempty (qualName v) [])) - typeParamType TypeParamDim {} = - Nothing - typeParam lvl (TypeParamType _ v loc) = Just (v, (lvl, locOf loc)) + typeParamType TypeParamDim {} = Nothing + typeParam lvl (TypeParamType l v loc) = Just (v, (lvl, l, locOf loc)) typeParam _ _ = Nothing bindParams :: @@ -1370,7 +1369,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do [ "## constraints:", unlines $ map prettyString cts', "## typarams:", - unlines (map (prettyString . bimap prettyNameString fst) (M.toList typarams)), + let f (lvl, l, _) = (lvl, l) + in unlines (map (prettyString . bimap prettyNameString f) (M.toList typarams)), "## tyvars':", unlines $ map (prettyString . first prettyNameString) $ M.toList tyvars', "## solution:", From 041e9e05971bb853edbbb7ea120b435ae913f1cc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 14 Jul 2024 23:58:47 +0200 Subject: [PATCH 236/296] Simplify. --- .../Futhark/TypeChecker/Constraints.hs | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index ca8f713d0d..bc458d0616 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -582,29 +582,29 @@ scopeCheck reason v v_lvl ty = do -- If a type variable has a liftedness constraint, we propagate that -- constraint to its solution. The actual checking for correct usage -- is done later. -liftednessCheck :: Reason -> TyVar -> Liftedness -> Type -> SolveM () -liftednessCheck reason v l (Scalar (TypeVar _ (QualName [] v2) _)) = do - v2_info <- maybeLookupTyVar v2 - case v2_info of +liftednessCheck :: Liftedness -> Type -> SolveM () +liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do + v_info <- maybeLookupTyVar v + case v_info of Nothing -> -- Is an opaque type. pure () - Just (TyVarSol _ v2_ty) -> - liftednessCheck reason v l v2_ty + Just (TyVarSol _ v_ty) -> + liftednessCheck l v_ty Just TyVarParam {} -> pure () - Just (TyVarUnsol lvl (TyVarFree loc v2_l)) - | l /= v2_l -> - setInfo v2 $ TyVarUnsol lvl $ TyVarFree loc (min l v2_l) + Just (TyVarUnsol lvl (TyVarFree loc v_l)) + | l /= v_l -> + setInfo v $ TyVarUnsol lvl $ TyVarFree loc (min l v_l) Just TyVarUnsol {} -> pure () -liftednessCheck _ _ _ (Scalar Prim {}) = pure () -liftednessCheck _ _ Lifted _ = pure () -liftednessCheck _ _ _ Array {} = pure () -liftednessCheck _ _ _ (Scalar Arrow {}) = pure () -liftednessCheck reason v l (Scalar (Record fs)) = - mapM_ (liftednessCheck reason v l) fs -liftednessCheck reason v l (Scalar (Sum cs)) = - mapM_ (mapM_ $ liftednessCheck reason v l) cs -liftednessCheck _ _ _ (Scalar TypeVar {}) = pure () +liftednessCheck _ (Scalar Prim {}) = pure () +liftednessCheck Lifted _ = pure () +liftednessCheck _ Array {} = pure () +liftednessCheck _ (Scalar Arrow {}) = pure () +liftednessCheck l (Scalar (Record fs)) = + mapM_ (liftednessCheck l) fs +liftednessCheck l (Scalar (Sum cs)) = + mapM_ (mapM_ $ liftednessCheck l) cs +liftednessCheck _ (Scalar TypeVar {}) = pure () solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do @@ -648,7 +648,7 @@ solveTyVar (tv, (lvl, TyVarFree loc l)) = do case tv_t of Right ty -> do scopeCheck (Reason loc) tv lvl ty - liftednessCheck (Reason loc) tv l ty + liftednessCheck l ty _ -> pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do (_, tv_t) <- lookupTyVar tv From 9834d6866e74620b0328775b1ddf0537cda9bf5d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 00:08:24 +0200 Subject: [PATCH 237/296] Simplify the level stuff. --- .../Futhark/TypeChecker/Constraints.hs | 140 +++++++++--------- 1 file changed, 68 insertions(+), 72 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index bc458d0616..1b66ef9b44 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -132,11 +132,11 @@ type TyParams = M.Map TyVar (Level, Liftedness, Loc) data TyVarSol = -- | Has been substituted with this. - TyVarSol Level Type + TyVarSol Type | -- | Is an explicit (rigid) type parameter in the source program. TyVarParam Level Liftedness Loc | -- | Not substituted yet; has this constraint. - TyVarUnsol Level TyVarInfo + TyVarUnsol TyVarInfo deriving (Show) newtype SolverState = SolverState @@ -147,14 +147,14 @@ newtype SolverState = SolverState initialState :: TyParams -> TyVars -> SolverState initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars where - f (lvl, info) = Right $ TyVarUnsol lvl info + f (_lvl, info) = Right $ TyVarUnsol info g (lvl, l, loc) = Right $ TyVarParam lvl l loc substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) substTyVar m v = case M.lookup v m of Just (Left v') -> substTyVar m v' - Just (Right (TyVarSol _ t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right (TyVarSol t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' Just (Right TyVarParam {}) -> Nothing Just (Right (TyVarUnsol {})) -> Nothing Nothing -> Nothing @@ -168,21 +168,21 @@ maybeLookupTyVar orig = do Just (Right info) -> pure $ Just info f orig -lookupTyVar :: TyVar -> SolveM (Level, Either TyVarInfo Type) +lookupTyVar :: TyVar -> SolveM (Either TyVarInfo Type) lookupTyVar orig = maybe bad unpack <$> maybeLookupTyVar orig where bad = error $ "Unknown tyvar: " <> prettyNameString orig unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig - unpack (TyVarSol lvl t) = (lvl, Right t) - unpack (TyVarUnsol lvl info) = (lvl, Left info) + unpack (TyVarSol t) = Right t + unpack (TyVarUnsol info) = Left info -- | Variable must be flexible. -lookupTyVarInfo :: TyVar -> SolveM (Level, TyVarInfo) +lookupTyVarInfo :: TyVar -> SolveM TyVarInfo lookupTyVarInfo v = do - (lvl, r) <- lookupTyVar v + r <- lookupTyVar v case r of - Left info -> pure (lvl, info) + Left info -> pure info Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v setLink :: TyVar -> VName -> SolveM () @@ -206,15 +206,15 @@ solution s = M.mapMaybe mkSubst $ solverTyVars s ) where - mkSubst (Right (TyVarSol _lvl t)) = + mkSubst (Right (TyVarSol t)) = Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t mkSubst (Left v') = Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (Right (TyVarUnsol _ (TyVarPrim _ pts))) = Just $ Left pts + mkSubst (Right (TyVarUnsol (TyVarPrim _ pts))) = Just $ Left pts mkSubst _ = Nothing - unconstrained (v, Right (TyVarUnsol _ (TyVarFree _ l))) = Just (v, l) + unconstrained (v, Right (TyVarUnsol (TyVarFree _ l))) = Just (v, l) unconstrained _ = Nothing newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} @@ -277,19 +277,19 @@ scopeViolation reason v1 ty v2 = <+> "is rigidly bound in a deeper scope." -- Precondition: 'v' is currently flexible. -subTyVar :: Reason -> VName -> Int -> Type -> SolveM () -subTyVar reason v v_lvl t = do +subTyVar :: Reason -> VName -> Type -> SolveM () +subTyVar reason v t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars -- Set a solution for v, then update info for t in case v has any -- odd constraints. - setInfo v (TyVarSol v_lvl t) + setInfo v (TyVarSol t) case (v_info, t) of - (Just (Right (TyVarUnsol _ TyVarFree {})), _) -> + (Just (Right (TyVarUnsol TyVarFree {})), _) -> pure () - ( Just (Right (TyVarUnsol _ (TyVarPrim _ v_pts))), + ( Just (Right (TyVarUnsol (TyVarPrim _ v_pts))), _ ) -> if t `elem` map (Scalar . Prim) v_pts @@ -300,7 +300,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty v_pts) "with" indent 2 (pretty t) - ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), Scalar (Sum cs2) ) -> if all (`elem` M.keys cs2) (M.keys cs1) @@ -311,7 +311,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Sum cs2)) - ( Just (Right (TyVarUnsol _ (TyVarSum _ cs1))), + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), _ ) -> typeError (locOf reason) mempty $ @@ -319,7 +319,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty t) - ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), Scalar (Record fs2) ) -> if all (`elem` M.keys fs2) (M.keys fs1) @@ -330,7 +330,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Record fs1)) "with record type" indent 2 (pretty (Record fs2)) - ( Just (Right (TyVarUnsol _ (TyVarRecord _ fs1))), + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), _ ) -> typeError (locOf reason) mempty $ @@ -338,7 +338,7 @@ subTyVar reason v v_lvl t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty t) - (Just (Right (TyVarUnsol _ (TyVarEql _))), _) -> + (Just (Right (TyVarUnsol (TyVarEql _))), _) -> mustSupportEql reason t -- -- Internal error cases @@ -355,32 +355,32 @@ subTyVar reason v v_lvl t = do unionTyVars :: Reason -> VName -> VName -> SolveM () unionTyVars reason v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - (t_lvl, t_info) <- lookupTyVarInfo t + t_info <- lookupTyVarInfo t -- Insert the link from v to t, and then update the info of t based -- on the existing info of v and t. setLink v t case (v_info, t_info) of - ( TyVarUnsol _ (TyVarFree _ v_l), + ( TyVarUnsol (TyVarFree _ v_l), TyVarFree t_loc t_l ) | v_l /= t_l -> - setInfo t $ TyVarUnsol t_lvl $ TyVarFree t_loc (min v_l t_l) + setInfo t $ TyVarUnsol $ TyVarFree t_loc (min v_l t_l) -- When either is completely unconstrained. - (TyVarUnsol _ TyVarFree {}, _) -> + (TyVarUnsol TyVarFree {}, _) -> pure () - ( TyVarUnsol _ info, + ( TyVarUnsol info, TyVarFree {} ) -> - setInfo t (TyVarUnsol t_lvl info) + setInfo t (TyVarUnsol info) -- -- TyVarPrim cases - ( TyVarUnsol _ info@TyVarPrim {}, + ( TyVarUnsol info@TyVarPrim {}, TyVarEql {} ) -> - setInfo t (TyVarUnsol t_lvl info) - ( TyVarUnsol _ (TyVarPrim _ v_pts), + setInfo t (TyVarUnsol info) + ( TyVarUnsol (TyVarPrim _ v_pts), TyVarPrim t_loc t_pts ) -> let pts = L.intersect v_pts t_pts @@ -391,15 +391,15 @@ unionTyVars reason v t = do indent 2 (pretty v_pts) "with type that must be one of" indent 2 (pretty t_pts) - else setInfo t (TyVarUnsol t_lvl (TyVarPrim t_loc pts)) - ( TyVarUnsol _ (TyVarPrim _ v_pts), + else setInfo t (TyVarUnsol (TyVarPrim t_loc pts)) + ( TyVarUnsol (TyVarPrim _ v_pts), TyVarRecord {} ) -> typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) "with type that must be record." - ( TyVarUnsol _ (TyVarPrim _ v_pts), + ( TyVarUnsol (TyVarPrim _ v_pts), TyVarSum {} ) -> typeError (locOf reason) mempty $ @@ -408,19 +408,19 @@ unionTyVars reason v t = do "with type that must be sum." -- -- TyVarSum cases - ( TyVarUnsol _ (TyVarSum _ cs1), + ( TyVarUnsol (TyVarSum _ cs1), TyVarSum loc cs2 ) -> do unifySharedConstructors reason cs1 cs2 let cs3 = cs1 <> cs2 - setInfo t (TyVarUnsol t_lvl (TyVarSum loc cs3)) - ( TyVarUnsol _ TyVarSum {}, + setInfo t (TyVarUnsol (TyVarSum loc cs3)) + ( TyVarUnsol TyVarSum {}, TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A sum type cannot be one of" indent 2 (pretty pts) - ( TyVarUnsol _ (TyVarSum _ cs1), + ( TyVarUnsol (TyVarSum _ cs1), TyVarRecord _ fs ) -> typeError (locOf reason) mempty $ @@ -428,25 +428,25 @@ unionTyVars reason v t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Scalar (Record fs))) - ( TyVarUnsol _ (TyVarSum _ cs1), + ( TyVarUnsol (TyVarSum _ cs1), TyVarEql _ ) -> mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases - ( TyVarUnsol _ (TyVarRecord _ fs1), + ( TyVarUnsol (TyVarRecord _ fs1), TyVarRecord loc fs2 ) -> do unifySharedFields reason fs1 fs2 let fs3 = fs1 <> fs2 - setInfo t (TyVarUnsol t_lvl (TyVarRecord loc fs3)) - ( TyVarUnsol _ TyVarRecord {}, + setInfo t (TyVarUnsol (TyVarRecord loc fs3)) + ( TyVarUnsol TyVarRecord {}, TyVarPrim _ pts ) -> typeError (locOf reason) mempty $ "A record type cannot be one of" indent 2 (pretty pts) - ( TyVarUnsol _ (TyVarRecord _ fs1), + ( TyVarUnsol (TyVarRecord _ fs1), TyVarSum _ cs ) -> typeError (locOf reason) mempty $ @@ -454,19 +454,19 @@ unionTyVars reason v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty (Scalar (Sum cs))) - ( TyVarUnsol _ (TyVarRecord _ fs1), + ( TyVarUnsol (TyVarRecord _ fs1), TyVarEql _ ) -> mapM_ (mustSupportEql reason) fs1 -- -- TyVarEql cases - (TyVarUnsol _ (TyVarEql _), TyVarPrim {}) -> + (TyVarUnsol (TyVarEql _), TyVarPrim {}) -> pure () - (TyVarUnsol _ (TyVarEql _), TyVarEql {}) -> + (TyVarUnsol (TyVarEql _), TyVarEql {}) -> pure () - (TyVarUnsol _ (TyVarEql _), TyVarRecord _ fs) -> + (TyVarUnsol (TyVarEql _), TyVarRecord _ fs) -> mustSupportEql reason $ Scalar $ Record fs - (TyVarUnsol _ (TyVarEql _), TyVarSum _ cs) -> + (TyVarUnsol (TyVarEql _), TyVarSum _ cs) -> mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases @@ -529,14 +529,14 @@ solveEq reason orig_t1 orig_t2 = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of Just (Left v') -> flexible v' - Just (Right (TyVarUnsol lvl _)) -> Just lvl - Just (Right TyVarSol {}) -> Nothing - Just (Right TyVarParam {}) -> Nothing - Nothing -> Nothing + Just (Right (TyVarUnsol _)) -> True + Just (Right TyVarSol {}) -> False + Just (Right TyVarParam {}) -> False + Nothing -> False sub t@(Scalar (TypeVar u (QualName [] v) [])) = case M.lookup v tyvars of Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (Right (TyVarSol _ t')) -> sub t' + Just (Right (TyVarSol t')) -> sub t' _ -> t sub t = t case (sub t1, sub t2) of @@ -546,18 +546,14 @@ solveEq reason orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (Nothing, Nothing) -> cannotUnify - (Just lvl, Nothing) -> subTyVar reason v1 lvl t2' - (Nothing, Just lvl) -> subTyVar reason v2 lvl t1' - (Just lvl1, Just lvl2) - | lvl1 <= lvl2 -> unionTyVars reason v1 v2 - | otherwise -> unionTyVars reason v2 v1 + (False, False) -> cannotUnify + (True, False) -> subTyVar reason v1 t2' + (False, True) -> subTyVar reason v2 t1' + (True, True) -> unionTyVars reason v1 v2 (Scalar (TypeVar _ (QualName [] v1) []), t2') - | Just lvl <- flexible v1 -> - subTyVar reason v1 lvl t2' + | flexible v1 -> subTyVar reason v1 t2' (t1', Scalar (TypeVar _ (QualName [] v2) [])) - | Just lvl <- flexible v2 -> - subTyVar reason v2 lvl t1' + | flexible v2 -> subTyVar reason v2 t1' (t1', t2') -> case unify t1' t2' of Nothing -> cannotUnify Just eqs -> mapM_ solveCt' eqs @@ -589,12 +585,12 @@ liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do Nothing -> -- Is an opaque type. pure () - Just (TyVarSol _ v_ty) -> + Just (TyVarSol v_ty) -> liftednessCheck l v_ty Just TyVarParam {} -> pure () - Just (TyVarUnsol lvl (TyVarFree loc v_l)) + Just (TyVarUnsol (TyVarFree loc v_l)) | l /= v_l -> - setInfo v $ TyVarUnsol lvl $ TyVarFree loc (min l v_l) + setInfo v $ TyVarUnsol $ TyVarFree loc (min l v_l) Just TyVarUnsol {} -> pure () liftednessCheck _ (Scalar Prim {}) = pure () liftednessCheck Lifted _ = pure () @@ -608,7 +604,7 @@ liftednessCheck _ (Scalar TypeVar {}) = pure () solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Left _ -> typeError loc mempty $ @@ -620,7 +616,7 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do Right _ -> pure () solveTyVar (tv, (_, TyVarSum loc cs1)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Left _ -> typeError loc mempty $ @@ -629,7 +625,7 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do indent 2 (pretty (Scalar (Sum cs1))) Right _ -> pure () solveTyVar (tv, (_, TyVarEql loc)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Left TyVarEql {} -> typeError loc mempty $ @@ -644,14 +640,14 @@ solveTyVar (tv, (_, TyVarEql loc)) = do indent 2 (align (pretty ty)) "does not support equality (may contain function)." solveTyVar (tv, (lvl, TyVarFree loc l)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Right ty -> do scopeCheck (Reason loc) tv lvl ty liftednessCheck l ty _ -> pure () solveTyVar (tv, (_, TyVarPrim loc pts)) = do - (_, tv_t) <- lookupTyVar tv + tv_t <- lookupTyVar tv case tv_t of Right ty | ty `elem` map (Scalar . Prim) pts -> pure () From e5088b58c06274bd4d2b6673a19f8bed5f297c3d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 10:22:50 +0200 Subject: [PATCH 238/296] Improve handling of branches. --- src/Language/Futhark/TypeChecker/Unify.hs | 29 +++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 47e04ee3aa..8bd5b68af7 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -807,10 +807,31 @@ unifyMostCommon :: StructType -> m (StructType, [VName]) unifyMostCommon usage t1 t2 = do - -- We are ignoring the dimensions here, because any mismatches - -- should be turned into fresh size variables. - let allOK _ _ _ _ _ = pure () - unifyWith allOK usage mempty noBreadCrumbs t1 t2 + -- Like 'unifySizes', except we do not fail on mismatches - these + -- are instead turned into fresh existential sizes in + -- 'newDimOnMismatch'. The most annoying thing is that we have to + -- replicate scope checking, because we don't want to link if it + -- would fail. + constraints <- getConstraints + + let varLevel v = fst <$> M.lookup v constraints + expLevel e = + L.foldl' max 0 $ mapMaybe varLevel $ S.toList $ fvVars $ freeInExp e + + onDims bcs bound nonrigid e1 e2 + | Just es <- similarExps e1 e2 = + mapM_ (uncurry $ onDims bcs bound nonrigid) es + onDims bcs _ nonrigid (Var v1 _ _) e2 + | Just lvl1 <- nonrigid (qualLeaf v1), + expLevel e2 < lvl1 = + linkVarToDim usage bcs (qualLeaf v1) lvl1 e2 + onDims bcs _ nonrigid e1 (Var v2 _ _) + | Just lvl2 <- nonrigid (qualLeaf v2), + expLevel e1 < lvl2 = + linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 + onDims _ _ _ _ _ = pure () + + unifyWith onDims usage mempty noBreadCrumbs t1 t2 t1' <- normTypeFully t1 t2' <- normTypeFully t2 newDimOnMismatch (locOf usage) t1' t2' From 104641407905ded29f48818407bf5c66e5457498 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 10:24:49 +0200 Subject: [PATCH 239/296] Supposed to be ambiguous. --- tests/automap/ambiguous0.fut | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/automap/ambiguous0.fut b/tests/automap/ambiguous0.fut index 58a663bf36..8c1ec556c3 100644 --- a/tests/automap/ambiguous0.fut +++ b/tests/automap/ambiguous0.fut @@ -1 +1,4 @@ +-- == +-- error: ambiguous + def ambig (xss : [][]i32) = i64.sum (length xss) From d0bea363dfcaa6e5b726e4dbf64832d84cbccaed Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 11:02:22 +0200 Subject: [PATCH 240/296] Now inferred differently. --- tests/tridag.fut | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/tests/tridag.fut b/tests/tridag.fut index a055dca86a..e8cc6718e8 100644 --- a/tests/tridag.fut +++ b/tests/tridag.fut @@ -34,32 +34,31 @@ -- } -def tridag(nn: i32, - b: *[]f64, d: *[]f64, - a: []f64, c: []f64 ): ([]f64,[]f64) = - if (nn == 1) +def tridag [nn] (b: *[]f64, d: *[nn]f64, + a: []f64, c: []f64 ): ([]f64,[]f64) = + if (nn == 1) --then ( b, map(\f64 (f64 x, f64 y) -> x / y, d, b) ) then (b, [d[0]/b[0]]) - else - let (b,d) = loop((b, d)) for i < (nn-1) do - let xm = a[i+1] / b[i] - let b[i+1] = b[i+1] - xm*c[i] - let d[i+1] = d[i+1] - xm*d[i] in - (b, d) + else + let (b,d) = loop((b, d)) for i < (nn-1) do + let xm = a[i+1] / b[i] + let b[i+1] = b[i+1] - xm*c[i] + let d[i+1] = d[i+1] - xm*d[i] in + (b, d) - let d[nn-1] = d[nn-1] / b[nn-1] in + let d[nn-1] = d[nn-1] / b[nn-1] in - let d = loop(d) for i < (nn-1) do - let k = nn - 2 - i - let d[k] = ( d[k] - c[k]*d[k+1] ) / b[k] in - d - in (b, d) + let d = loop(d) for i < (nn-1) do + let k = nn - 2 - i + let d[k] = ( d[k] - c[k]*d[k+1] ) / b[k] in + d + in (b, d) def main: ([]f64,[]f64) = - let nn = reduce (+) 0 ([1,2,3,4]) - let a = replicate nn 3.33 - let b = map (\x -> f64.i64(x) + 1.0) (iota(nn)) - let c = map (\x -> 1.11*f64.i64(x) + 0.5) (iota(nn)) - let d = map (\x -> 1.01*f64.i64(x) + 0.25) (iota(nn)) - in tridag(i32.i64 nn, b, d, a, c) + let nn = reduce (+) 0 ([1,2,3,4]) + let a = replicate nn 3.33 + let b = map (\x -> f64.i64(x) + 1.0) (iota(nn)) + let c = map (\x -> 1.11*f64.i64(x) + 0.5) (iota(nn)) + let d = map (\x -> 1.01*f64.i64(x) + 0.25) (iota(nn)) + in tridag(b, d, a, c) From 0a07d320d058ed99b21affa891f5779b6360f141 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jul 2024 22:16:12 +0200 Subject: [PATCH 241/296] Must be more explicit now. --- tests/shapes/polymorphic4.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shapes/polymorphic4.fut b/tests/shapes/polymorphic4.fut index b44af86c34..acab851f67 100644 --- a/tests/shapes/polymorphic4.fut +++ b/tests/shapes/polymorphic4.fut @@ -2,6 +2,6 @@ -- == -- error: do not match -def foo f x : [1]i32 = +def foo (f : (n: i64) -> [n]i32) x : [1]i32 = let r = if true then f x : []i32 else [1i32] in r From 80cfae5c78b307aa45e38a7b28e1a425153f2b8f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 16 Jul 2024 17:46:11 +0200 Subject: [PATCH 242/296] This is OK now. --- tests/sumtypes/coerce1.fut | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/sumtypes/coerce1.fut b/tests/sumtypes/coerce1.fut index eeff92a2a3..b6bfe42f3d 100644 --- a/tests/sumtypes/coerce1.fut +++ b/tests/sumtypes/coerce1.fut @@ -1,5 +1,4 @@ -- == --- error: Ambiguous size.*anonymous size type opt 't = #some t | #none From 66be58ecb9b8011fb0fe99b2e519984da28fabc9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 16 Jul 2024 17:46:25 +0200 Subject: [PATCH 243/296] Fix Constr. --- src/Language/Futhark/TypeChecker/Terms.hs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index d2f03ecc9c..06125f214d 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -796,6 +796,13 @@ checkExp (AppExp (Loop _ mergepat mergeexp form loopbody loc) _) = do checkExp (Constr name es (Info t) loc) = do t' <- replaceTyVars loc t es' <- mapM checkExp es + case t' of + Scalar (Sum cs) + | Just name_ts <- M.lookup name cs -> + zipWithM_ (unify $ mkUsage loc "inferred variant") name_ts $ + map typeOf es' + _ -> + error $ "checkExp Constr: " <> prettyString t' pure $ Constr name es' (Info t') loc checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e From 7068b88970e7d118f038731f8f5f16853f0d3954 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 16 Jul 2024 23:02:56 +0200 Subject: [PATCH 244/296] Fiddle with liftedness checking. --- src/Language/Futhark/TypeChecker/Terms.hs | 42 +++++++++---------- .../Futhark/TypeChecker/Terms/Loop.hs | 6 +-- .../Futhark/TypeChecker/Terms/Monad.hs | 29 +++++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 9 ++-- src/Language/Futhark/TypeChecker/Unify.hs | 35 ---------------- 5 files changed, 56 insertions(+), 65 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 06125f214d..082e60f1ea 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -383,12 +383,14 @@ checkExp (ArrayLit all_es _ loc) = [] -> do et <- newTypeVar loc "t" t <- arrayOfM loc et (Shape [sizeFromInteger 0 mempty]) + mustBeUnlifted (locOf loc) et pure $ ArrayLit [] (Info t) loc e : es -> do e' <- checkExp e et <- expType e' es' <- mapM (unifies "type of first array element" et <=< checkExp) es t <- arrayOfM loc et (Shape [sizeFromInteger (genericLength all_es) mempty]) + mustBeUnlifted (locOf loc) et pure $ ArrayLit (e' : es') (Info t) loc checkExp (AppExp (Range start maybe_step end loc) _) = do start' <- checkExp start @@ -519,24 +521,6 @@ checkExp (Project k e _ loc) = do | Just kt <- M.lookup k fs -> pure $ Project k e' (Info kt) loc _ -> error $ "checkExp Project: " <> show t -checkExp (AppExp (If e1 e2 e3 loc) _) = do - e1' <- checkExp e1 - e2' <- checkExp e2 - e3' <- checkExp e3 - - let bool = Scalar $ Prim Bool - e1_t <- expType e1' - onFailure (CheckingRequired [bool] e1_t) $ - unify (mkUsage e1' "use as 'if' condition") bool e1_t - - (brancht, retext) <- unifyBranches loc e2' e3' - - zeroOrderType - (mkUsage loc "returning value of this type from 'if' expression") - "type returned from branch" - brancht - - pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes brancht retext) checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (QualParens (modname, modnameloc) e loc) = do @@ -804,14 +788,28 @@ checkExp (Constr name es (Info t) loc) = do _ -> error $ "checkExp Constr: " <> prettyString t' pure $ Constr name es' (Info t') loc +checkExp (AppExp (If e1 e2 e3 loc) _) = do + e1' <- checkExp e1 + e2' <- checkExp e2 + e3' <- checkExp e3 + + let bool = Scalar $ Prim Bool + e1_t <- expType e1' + onFailure (CheckingRequired [bool] e1_t) $ + unify (mkUsage e1' "use as 'if' condition") bool e1_t + + (t, retext) <- unifyBranches loc e2' e3' + + mustBeOrderZero (locOf loc) t + + pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes t retext) checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e mt <- expType e' (cs', t, retext) <- checkCases mt cs - zeroOrderType - (mkUsage loc "being returned 'match'") - "type returned from pattern match" - t + + mustBeOrderZero (locOf loc) t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes t retext) checkExp (Attr info e loc) = Attr <$> checkAttr info <*> checkExp e <*> pure loc diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index cf447e7408..334c67ed5b 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -136,11 +136,7 @@ checkLoop :: checkLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do mergeexp' <- checkExp mergeexp known_before <- M.keysSet <$> getConstraints - zeroOrderType - (mkUsage mergeexp "use as loop variable") - "type used as loop variable" - . toStruct - =<< expTypeFully mergeexp' + mustBeOrderZero (locOf mergeexp) =<< expTypeFully mergeexp' -- The handling of dimension sizes is a bit intricate, but very -- similar to checking a function, followed by checking a call to diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 4d9be5deff..36d9c8bade 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -29,6 +29,8 @@ module Language.Futhark.TypeChecker.Terms.Monad replaceTyVars, updateTypes, Names, + mustBeOrderZero, + mustBeUnlifted, -- * Primitive checking unifies, @@ -618,6 +620,33 @@ updateTypes = astMap tv mapOnResRetType = normTypeFully } +mustBeOrderZero :: Loc -> StructType -> TermTypeM () +mustBeOrderZero loc t = do + constraints <- getConstraints + let liftedType v = + case M.lookup v constraints of + Just (_, ParamType Lifted _) -> True + _ -> False + when (not (orderZero t) || any liftedType (typeVars t)) $ + typeError loc mempty $ + textwrap "This expression may not be of function type, but is inferred to be of type" + indent 2 (align (pretty t)) + "which may be a function." + +mustBeUnlifted :: Loc -> StructType -> TermTypeM () +mustBeUnlifted loc t = do + constraints <- getConstraints + let liftedType v = + case M.lookup v constraints of + Just (_, ParamType Lifted _) -> True + Just (_, ParamType SizeLifted _) -> True + _ -> False + when (not (orderZero t) || any liftedType (typeVars t)) $ + typeError loc mempty $ + textwrap "This expression must be of unlifted type, but is inferred to be of type" + indent 2 (align (pretty t)) + "which may be a function or a value with hidden sizes." + --- Basic checking unifies :: T.Text -> StructType -> Exp -> TermTypeM Exp diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index e58154801e..8fabb67962 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1191,10 +1191,13 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e e_t <- expType e' - (cs', t) <- checkCases e_t cs - t' <- asStructType t - pure $ AppExp (Match e' cs' loc) (Info $ AppRes t' []) + + match_t <- newType loc SizeLifted "match_t" NoUniqueness + ctEq (Reason (locOf loc)) match_t t + + match_t' <- asStructType match_t + pure $ AppExp (Match e' cs' loc) (Info $ AppRes match_t' []) -- checkExp (AppExp (Loop _ pat arg form body loc) _) = do arg' <- checkExp arg diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 8bd5b68af7..7d0f39bf91 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -14,7 +14,6 @@ module Language.Futhark.TypeChecker.Unify noBreadCrumbs, hasNoBreadCrumbs, dimNotes, - zeroOrderType, arrayElemType, normType, normTypeFully, @@ -665,40 +664,6 @@ linkVarToDim usage bcs vn lvl e = do _ -> modifyConstraints $ M.insert dim' (lvl, c) checkVar _ _ = pure () -zeroOrderTypeWith :: - (MonadUnify m) => - Usage -> - BreadCrumbs -> - StructType -> - m () -zeroOrderTypeWith usage bcs t = do - unless (orderZero t) $ - unifyError usage mempty bcs $ - "Type" indent 2 (pretty t) "found to be functional." - mapM_ mustBeZeroOrder . S.toList . typeVars =<< normType t - where - mustBeZeroOrder vn = do - constraints <- getConstraints - case M.lookup vn constraints of - Just (lvl, NoConstraint _ _) -> - modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage) - Just (_, ParamType Lifted ploc) -> - unifyError usage mempty bcs $ - "Type parameter" - <+> dquotes (prettyName vn) - <+> "at" - <+> pretty (locStr ploc) - <+> "may be a function." - _ -> pure () - --- | Assert that this type must be zero-order. -zeroOrderType :: - (MonadUnify m) => Usage -> T.Text -> StructType -> m () -zeroOrderType usage desc = - zeroOrderTypeWith usage $ breadCrumb bc noBreadCrumbs - where - bc = Matching $ "When checking" <+> textwrap desc - arrayElemTypeWith :: (MonadUnify m, Pretty (Shape dim), Pretty u) => Usage -> From 2807accc420ed887c92df1249fea4e48f549398f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 4 Sep 2024 16:34:16 +0200 Subject: [PATCH 245/296] Work on error message. --- .../Futhark/TypeChecker/Constraints.hs | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1b66ef9b44..16b5650e94 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -481,49 +481,57 @@ unionTyVars reason v t = do isParam = error $ "Type name is a type parameter: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. -unify :: Type -> Type -> Maybe [(Type, Type)] +unify :: Type -> Type -> Either (Doc a) [(Type, Type)] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) - | pt1 == pt2 = Just [] + | pt1 == pt2 = Right [] unify (Scalar (TypeVar _ (QualName _ v1) targs1)) (Scalar (TypeVar _ (QualName _ v2) targs2)) | v1 == v2 = - Just $ mapMaybe f $ zip targs1 targs2 + Right $ mapMaybe f $ zip targs1 targs2 where f (TypeArgType t1, TypeArgType t2) = Just (t1, t2) f _ = Nothing unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Just [(t1a, t2a), (t1r', t2r')] + Right [(t1a, t2a), (t1r', t2r')] where t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = - Just $ M.elems $ M.intersectionWith (,) fs1 fs2 + Right $ M.elems $ M.intersectionWith (,) fs1 fs2 + | otherwise = + let missing = + filter (`notElem` M.keys fs1) (M.keys fs2) + <> filter (`notElem` M.keys fs2) (M.keys fs1) + in Left $ + "Unshared fields:" <+> commasep (map pretty missing) <> "." unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = fmap concat . forM cs' $ \(ts1, ts2) -> do - guard $ length ts1 == length ts2 - Just $ zip ts1 ts2 + if length ts1 == length ts2 + then Right $ zip ts1 ts2 + else Left mempty where cs' = M.elems $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = - Just [(t1', t2')] -unify _ _ = Nothing + Right [(t1', t2')] +unify _ _ = Left mempty solveEq :: Reason -> Type -> Type -> SolveM () solveEq reason orig_t1 orig_t2 = do solveCt' (orig_t1, orig_t2) where - cannotUnify = do + cannotUnify details = do tyvars <- gets solverTyVars typeError (locOf reason) mempty $ "Cannot unify" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) "with" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) + <> details solveCt' (t1, t2) = do tyvars <- gets solverTyVars @@ -546,7 +554,7 @@ solveEq reason orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify + (False, False) -> cannotUnify mempty (True, False) -> subTyVar reason v1 t2' (False, True) -> subTyVar reason v2 t1' (True, True) -> unionTyVars reason v1 v2 @@ -555,8 +563,8 @@ solveEq reason orig_t1 orig_t2 = do (t1', Scalar (TypeVar _ (QualName [] v2) [])) | flexible v2 -> subTyVar reason v2 t1' (t1', t2') -> case unify t1' t2' of - Nothing -> cannotUnify - Just eqs -> mapM_ solveCt' eqs + Left details -> cannotUnify details + Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () solveCt ct = From d8dbec2b5451be86b02d7c1cab436902f24ec3ad Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:13:39 +0200 Subject: [PATCH 246/296] Fix some mistaken tests. --- tests/shapes/error6.fut | 2 +- tests/shapes/shape_duplicate.fut | 4 ++-- tests/shapes/size-inference2.fut | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/shapes/error6.fut b/tests/shapes/error6.fut index 3fda73dd6e..5c7332d94a 100644 --- a/tests/shapes/error6.fut +++ b/tests/shapes/error6.fut @@ -2,7 +2,7 @@ -- == -- error: "n" -def ap (f: (n: i64) -> [n]i32) (k: i64) : [k]i32 = +def ap (f: (n: i64) -> [n]i64) (k: i64) : [k]i64 = f k def main = ap (\n -> iota (n+1)) 10 diff --git a/tests/shapes/shape_duplicate.fut b/tests/shapes/shape_duplicate.fut index 3bbd5f391f..b29e1e7cbe 100644 --- a/tests/shapes/shape_duplicate.fut +++ b/tests/shapes/shape_duplicate.fut @@ -4,7 +4,7 @@ -- == -- error: do not match -def f [n][m] ((_, elems: [n]i32): (i32,[m]i32)): i32 = +def f [n][m] ((_, elems: [n]i64): (i64,[m]i64)): i64 = n + m + elems[0] -def main (x: i32, y: []i32): i32 = f (x, y) +def main (x: i64, y: []i64): i64 = f (x, y) diff --git a/tests/shapes/size-inference2.fut b/tests/shapes/size-inference2.fut index b6f59d4a9a..2804383f72 100644 --- a/tests/shapes/size-inference2.fut +++ b/tests/shapes/size-inference2.fut @@ -2,4 +2,4 @@ -- == -- error: Sizes.*do not match -def main [n] (xs: [n]i32) : [n]i32 = iota (length xs) +def main [n] (xs: [n]i32) : [n]i64 = iota (length xs) From 90d36781ff5c5e7da0fbca5bbc5c1f6c8caa9b5b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:15:18 +0200 Subject: [PATCH 247/296] Linebreak. --- src/Language/Futhark/TypeChecker/Constraints.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 16b5650e94..351e738f33 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -531,7 +531,7 @@ solveEq reason orig_t1 orig_t2 = do indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) "with" indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) - <> details + details solveCt' (t1, t2) = do tyvars <- gets solverTyVars From bc804c1cda489692196b6cba14c0adfba612500e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:19:36 +0200 Subject: [PATCH 248/296] Fix another test. --- tests/shapes/error4.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shapes/error4.fut b/tests/shapes/error4.fut index b842bdf44a..cf75bfe897 100644 --- a/tests/shapes/error4.fut +++ b/tests/shapes/error4.fut @@ -2,7 +2,7 @@ -- == -- error: Sizes.*"n".*do not match -def f (g: (n: i64) -> [n]i32) (l: i64): i32 = +def f (g: (n: i64) -> [n]i64) (l: i64): i64 = (g l)[0] def main = f (\n : []i64 -> iota (n+1)) From 2a0cc26e09032980080c2e041c551e9a7b3651db Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 09:41:23 +0200 Subject: [PATCH 249/296] Fix more error messages. --- tests/issue1787.fut | 2 +- tests/issue514.fut | 2 +- tests/types/inference-error4.fut | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/issue1787.fut b/tests/issue1787.fut index 90cb01dd72..ed4aef3fae 100644 --- a/tests/issue1787.fut +++ b/tests/issue1787.fut @@ -1,5 +1,5 @@ -- == --- error: found to be functional +-- error: function type entry main: i32 -> i32 -> i32 = ((true, (.0)), (false, (.1))) diff --git a/tests/issue514.fut b/tests/issue514.fut index 2f70eca04f..057d69b71a 100644 --- a/tests/issue514.fut +++ b/tests/issue514.fut @@ -1,4 +1,4 @@ -- == --- error: issue514.fut:4:26-36 +-- error: issue514.fut:4:13-22 def main = (2.0 + 3.0) / (2 + 3i32) diff --git a/tests/types/inference-error4.fut b/tests/types/inference-error4.fut index 809b98302a..0ff781f33a 100644 --- a/tests/types/inference-error4.fut +++ b/tests/types/inference-error4.fut @@ -1,6 +1,6 @@ -- If something is used in a loop, it cannot later be inferred as a -- function. -- == --- error: functional +-- error: function type def f x = (loop x = x for i < 10 do x, x 2) From b77f1c2d742dace00425ea2746c6869966eaaec4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 10:57:28 +0200 Subject: [PATCH 250/296] Less weird. --- tests/ascription0.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ascription0.fut b/tests/ascription0.fut index 5aff8c054a..8c3a50e026 100644 --- a/tests/ascription0.fut +++ b/tests/ascription0.fut @@ -3,6 +3,6 @@ -- == -- error: match -def main(x: i32, y:i32): i32 = +def main(x: i32, y:i32): (bool,bool) = let (((a): i32), b: i32) : (bool,bool) = (x,y) in (a,b) From dee11882979616b2a8e7e8189ce3932914c0410f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 11:12:30 +0200 Subject: [PATCH 251/296] Introduce breadcrumbs in constraint solver. --- futhark.cabal | 1 + .../Futhark/TypeChecker/Constraints.hs | 66 ++++++++------- src/Language/Futhark/TypeChecker/Error.hs | 79 ++++++++++++++++++ .../Futhark/TypeChecker/Terms/Monad.hs | 1 + src/Language/Futhark/TypeChecker/Unify.hs | 83 ++++--------------- 5 files changed, 132 insertions(+), 98 deletions(-) create mode 100644 src/Language/Futhark/TypeChecker/Error.hs diff --git a/futhark.cabal b/futhark.cabal index 054e5e2917..bcd12b7a12 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -421,6 +421,7 @@ library Language.Futhark.TypeChecker Language.Futhark.TypeChecker.Consumption Language.Futhark.TypeChecker.Constraints + Language.Futhark.TypeChecker.Error Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 351e738f33..9c9e7869c9 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -27,6 +27,7 @@ import Data.Maybe import Data.Set qualified as S import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..)) import Language.Futhark.TypeChecker.Types (substTyVars) @@ -237,13 +238,14 @@ occursCheck reason v tp = do unifySharedConstructors :: Reason -> + BreadCrumbs -> M.Map Name [Type] -> M.Map Name [Type] -> SolveM () -unifySharedConstructors reason cs1 cs2 = +unifySharedConstructors reason bcs cs1 cs2 = forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> if length ts1 == length ts2 - then zipWithM_ (solveEq reason) ts1 ts2 + then zipWithM_ (solveEq reason bcs) ts1 ts2 else typeError (locOf reason) mempty $ "Cannot unify type with constructor" @@ -254,12 +256,13 @@ unifySharedConstructors reason cs1 cs2 = unifySharedFields :: Reason -> + BreadCrumbs -> M.Map Name Type -> M.Map Name Type -> SolveM () -unifySharedFields reason fs1 fs2 = +unifySharedFields reason bcs fs1 fs2 = forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_f, (ts1, ts2)) -> - solveEq reason ts1 ts2 + solveEq reason bcs ts1 ts2 mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () @@ -277,8 +280,8 @@ scopeViolation reason v1 ty v2 = <+> "is rigidly bound in a deeper scope." -- Precondition: 'v' is currently flexible. -subTyVar :: Reason -> VName -> Type -> SolveM () -subTyVar reason v t = do +subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () +subTyVar reason bcs v t = do occursCheck reason v t v_info <- gets $ M.lookup v . solverTyVars @@ -304,7 +307,7 @@ subTyVar reason v t = do Scalar (Sum cs2) ) -> if all (`elem` M.keys cs2) (M.keys cs1) - then unifySharedConstructors reason cs1 cs2 + then unifySharedConstructors reason bcs cs1 cs2 else typeError (locOf reason) mempty $ "Cannot unify type with constructors" @@ -323,7 +326,7 @@ subTyVar reason v t = do Scalar (Record fs2) ) -> if all (`elem` M.keys fs2) (M.keys fs1) - then unifySharedFields reason fs1 fs2 + then unifySharedFields reason bcs fs1 fs2 else typeError (locOf reason) mempty $ "Cannot unify record type with fields" @@ -352,8 +355,8 @@ subTyVar reason v t = do error $ "subTyVar: Nothing v: " <> prettyNameString v -- Precondition: 'v' and 't' are both currently flexible. -unionTyVars :: Reason -> VName -> VName -> SolveM () -unionTyVars reason v t = do +unionTyVars :: Reason -> BreadCrumbs -> VName -> VName -> SolveM () +unionTyVars reason bcs v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars t_info <- lookupTyVarInfo t @@ -411,7 +414,7 @@ unionTyVars reason v t = do ( TyVarUnsol (TyVarSum _ cs1), TyVarSum loc cs2 ) -> do - unifySharedConstructors reason cs1 cs2 + unifySharedConstructors reason bcs cs1 cs2 let cs3 = cs1 <> cs2 setInfo t (TyVarUnsol (TyVarSum loc cs3)) ( TyVarUnsol TyVarSum {}, @@ -437,7 +440,7 @@ unionTyVars reason v t = do ( TyVarUnsol (TyVarRecord _ fs1), TyVarRecord loc fs2 ) -> do - unifySharedFields reason fs1 fs2 + unifySharedFields reason bcs fs1 fs2 let fs3 = fs1 <> fs2 setInfo t (TyVarUnsol (TyVarRecord loc fs3)) ( TyVarUnsol TyVarRecord {}, @@ -481,7 +484,7 @@ unionTyVars reason v t = do isParam = error $ "Type name is a type parameter: " <> prettyNameString v -- Unify at the root, emitting new equalities that must hold. -unify :: Type -> Type -> Either (Doc a) [(Type, Type)] +unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Right [] unify @@ -490,16 +493,19 @@ unify | v1 == v2 = Right $ mapMaybe f $ zip targs1 targs2 where - f (TypeArgType t1, TypeArgType t2) = Just (t1, t2) + f (TypeArgType t1, TypeArgType t2) = Just (mempty, (t1, t2)) f _ = Nothing unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Right [(t1a, t2a), (t1r', t2r')] + Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] where t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = - Right $ M.elems $ M.intersectionWith (,) fs1 fs2 + Right $ + map (first matchingField) $ + M.toList $ + M.intersectionWith (,) fs1 fs2 | otherwise = let missing = filter (`notElem` M.keys fs1) (M.keys fs2) @@ -510,19 +516,19 @@ unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = fmap concat . forM cs' $ \(ts1, ts2) -> do if length ts1 == length ts2 - then Right $ zip ts1 ts2 + then Right $ zipWith (curry (mempty,)) ts1 ts2 else Left mempty where cs' = M.elems $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = - Right [(t1', t2')] + Right [(mempty, (t1', t2'))] unify _ _ = Left mempty -solveEq :: Reason -> Type -> Type -> SolveM () -solveEq reason orig_t1 orig_t2 = do - solveCt' (orig_t1, orig_t2) +solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () +solveEq reason obcs orig_t1 orig_t2 = do + solveCt' (obcs, (orig_t1, orig_t2)) where cannotUnify details = do tyvars <- gets solverTyVars @@ -533,7 +539,7 @@ solveEq reason orig_t1 orig_t2 = do indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) details - solveCt' (t1, t2) = do + solveCt' (bcs, (t1, t2)) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of Just (Left v') -> flexible v' @@ -554,22 +560,22 @@ solveEq reason orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify mempty - (True, False) -> subTyVar reason v1 t2' - (False, True) -> subTyVar reason v2 t1' - (True, True) -> unionTyVars reason v1 v2 + (False, False) -> cannotUnify $ pretty bcs + (True, False) -> subTyVar reason bcs v1 t2' + (False, True) -> subTyVar reason bcs v2 t1' + (True, True) -> unionTyVars reason bcs v1 v2 (Scalar (TypeVar _ (QualName [] v1) []), t2') - | flexible v1 -> subTyVar reason v1 t2' + | flexible v1 -> subTyVar reason bcs v1 t2' (t1', Scalar (TypeVar _ (QualName [] v2) [])) - | flexible v2 -> subTyVar reason v2 t1' + | flexible v2 -> subTyVar reason bcs v2 t1' (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify details + Left details -> cannotUnify $ pretty bcs details Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () solveCt ct = case ct of - CtEq reason t1 t2 -> solveEq reason t1 t2 + CtEq reason t1 t2 -> solveEq reason mempty t1 t2 CtAM {} -> pure () -- Good vibes only. scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Error.hs b/src/Language/Futhark/TypeChecker/Error.hs new file mode 100644 index 0000000000..d4fbc70aad --- /dev/null +++ b/src/Language/Futhark/TypeChecker/Error.hs @@ -0,0 +1,79 @@ +-- | Fundamental facilities for constructing type error messages. +module Language.Futhark.TypeChecker.Error + ( -- * Breadcrumbs + BreadCrumbs, + hasNoBreadCrumbs, + matchingField, + matchingConstructor, + matchingTypes, + matching, + ) +where + +import Futhark.Util.Pretty +import Language.Futhark + +-- | A piece of information that describes what process the type +-- checker currently performing. This is used to give better error +-- messages for unification errors. +data BreadCrumb + = MatchingTypes StructType StructType + | MatchingFields [Name] + | MatchingConstructor Name + | Matching (Doc ()) + +instance Pretty BreadCrumb where + pretty (MatchingTypes t1 t2) = + "When matching type" + indent 2 (pretty t1) + "with" + indent 2 (pretty t2) + pretty (MatchingFields fields) = + "When matching types of record field" + <+> dquotes (mconcat $ punctuate "." $ map pretty fields) + <> dot + pretty (MatchingConstructor c) = + "When matching types of constructor" <+> dquotes (pretty c) <> dot + pretty (Matching s) = + unAnnotate s + +-- | Unification failures can occur deep down inside complicated types +-- (consider nested records). We leave breadcrumbs behind us so we can +-- report the path we took to find the mismatch. When combining +-- breadcrumbs with the 'Semigroup' instance, put the innermost +-- breadcrumbs to the left. +newtype BreadCrumbs = BreadCrumbs [BreadCrumb] + +instance Semigroup BreadCrumbs where + BreadCrumbs (MatchingFields xs : bcs1) <> BreadCrumbs (MatchingFields ys : bcs2) = + BreadCrumbs $ MatchingFields (ys <> xs) : bcs1 <> bcs2 + BreadCrumbs bcs1 <> BreadCrumbs bcs2 = + BreadCrumbs $ bcs1 <> bcs2 + +instance Monoid BreadCrumbs where + mempty = BreadCrumbs [] + +-- | Is the path empty? +hasNoBreadCrumbs :: BreadCrumbs -> Bool +hasNoBreadCrumbs (BreadCrumbs []) = True +hasNoBreadCrumbs _ = False + +-- | Matching a record field. +matchingField :: Name -> BreadCrumbs +matchingField f = BreadCrumbs [MatchingFields [f]] + +-- | Matching two types. +matchingTypes :: StructType -> StructType -> BreadCrumbs +matchingTypes t1 t2 = BreadCrumbs [MatchingTypes t1 t2] + +-- | Matching a constructor. +matchingConstructor :: Name -> BreadCrumbs +matchingConstructor c = BreadCrumbs [MatchingConstructor c] + +-- | Matching anything. +matching :: Doc () -> BreadCrumbs +matching d = BreadCrumbs [Matching d] + +instance Pretty BreadCrumbs where + pretty (BreadCrumbs []) = mempty + pretty (BreadCrumbs bcs) = line <> stack (map pretty bcs) diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 36d9c8bade..ba30adbaeb 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -66,6 +66,7 @@ import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints (TyVar) +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod, stateNameSource) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 7d0f39bf91..d7898df2c1 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -10,9 +10,6 @@ module Language.Futhark.TypeChecker.Unify MonadUnify (..), Rigidity (..), RigidSource (..), - BreadCrumbs, - noBreadCrumbs, - hasNoBreadCrumbs, dimNotes, arrayElemType, normType, @@ -33,57 +30,10 @@ import Data.Set qualified as S import Data.Text qualified as T import Futhark.Util.Pretty import Language.Futhark +import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad hiding (BoundV) import Language.Futhark.TypeChecker.Types --- | A piece of information that describes what process the type --- checker currently performing. This is used to give better error --- messages for unification errors. -data BreadCrumb - = MatchingTypes StructType StructType - | MatchingFields [Name] - | MatchingConstructor Name - | Matching (Doc ()) - -instance Pretty BreadCrumb where - pretty (MatchingTypes t1 t2) = - "When matching type" - indent 2 (pretty t1) - "with" - indent 2 (pretty t2) - pretty (MatchingFields fields) = - "When matching types of record field" - <+> dquotes (mconcat $ punctuate "." $ map pretty fields) - <> dot - pretty (MatchingConstructor c) = - "When matching types of constructor" <+> dquotes (pretty c) <> dot - pretty (Matching s) = - unAnnotate s - --- | Unification failures can occur deep down inside complicated types --- (consider nested records). We leave breadcrumbs behind us so we --- can report the path we took to find the mismatch. -newtype BreadCrumbs = BreadCrumbs [BreadCrumb] - --- | An empty path. -noBreadCrumbs :: BreadCrumbs -noBreadCrumbs = BreadCrumbs [] - --- | Is the path empty? -hasNoBreadCrumbs :: BreadCrumbs -> Bool -hasNoBreadCrumbs (BreadCrumbs xs) = null xs - --- | Drop a breadcrumb on the path behind you. -breadCrumb :: BreadCrumb -> BreadCrumbs -> BreadCrumbs -breadCrumb (MatchingFields xs) (BreadCrumbs (MatchingFields ys : bcs)) = - BreadCrumbs $ MatchingFields (ys ++ xs) : bcs -breadCrumb bc (BreadCrumbs bcs) = - BreadCrumbs $ bc : bcs - -instance Pretty BreadCrumbs where - pretty (BreadCrumbs []) = mempty - pretty (BreadCrumbs bcs) = line <> stack (map pretty bcs) - -- | A usage that caused a type constraint. data Usage = Usage (Maybe T.Text) Loc deriving (Show) @@ -387,7 +337,7 @@ unifyWith onDims usage = subunify False ) | tn == arg_tn, length targs == length arg_targs -> do - let bcs' = breadCrumb (Matching "When matching type arguments.") bcs + let bcs' = matching "When matching type arguments." <> bcs zipWithM_ (unifyTypeArg bcs') targs arg_targs ( Scalar (TypeVar _ (QualName [] v1) []), Scalar (TypeVar _ (QualName [] v2) []) @@ -439,13 +389,13 @@ unifyWith onDims usage = subunify False subunify (not ord) bound - (breadCrumb (Matching "When matching parameter types.") bcs) + (matching "When matching parameter types." <> bcs) a1 a2 subunify ord bound' - (breadCrumb (Matching "When matching return types.") bcs) + (matching "When matching return types." <> bcs) (toStruct b1') (toStruct b2') @@ -511,7 +461,7 @@ unifySizes usage bcs _ _ e1 e2 = do -- | Unifies two types. unify :: (MonadUnify m) => Usage -> StructType -> StructType -> m () -unify usage = unifyWith (unifySizes usage) usage mempty noBreadCrumbs +unify usage = unifyWith (unifySizes usage) usage mempty mempty occursCheck :: (MonadUnify m) => @@ -597,14 +547,13 @@ linkVarToType usage bound bcs vn lvl tp_unnorm = do <> " used as size(s) would go out of scope." let unliftedBcs unlifted_usage = - breadCrumb - ( Matching $ - "When verifying that" - <+> dquotes (prettyName vn) - <+> textwrap "is not instantiated with a function type, due to" - <+> pretty unlifted_usage + matching + ( "When verifying that" + <+> dquotes (prettyName vn) + <+> textwrap "is not instantiated with a function type, due to" + <+> pretty unlifted_usage ) - bcs + <> bcs case snd <$> M.lookup vn constraints of Just (NoConstraint Unlifted unlift_usage) -> do @@ -699,9 +648,7 @@ arrayElemType :: TypeBase dim u -> m () arrayElemType usage desc = - arrayElemTypeWith usage $ breadCrumb bc noBreadCrumbs - where - bc = Matching $ "When checking" <+> textwrap desc + arrayElemTypeWith usage $ matching $ "When checking" <+> textwrap desc unifySharedFields :: (MonadUnify m) => @@ -714,7 +661,7 @@ unifySharedFields :: m () unifySharedFields onDims usage bound bcs fs1 fs2 = forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (t1, t2)) -> - unifyWith onDims usage bound (breadCrumb (MatchingFields [f]) bcs) t1 t2 + unifyWith onDims usage bound (matchingField f <> bcs) t1 t2 unifySharedConstructors :: (MonadUnify m) => @@ -731,7 +678,7 @@ unifySharedConstructors onDims usage bound bcs cs1 cs2 = where unifyConstructor c f1 f2 | length f1 == length f2 = do - let bcs' = breadCrumb (MatchingConstructor c) bcs + let bcs' = matchingConstructor c <> bcs zipWithM_ (unifyWith onDims usage bound bcs') f1 f2 | otherwise = unifyError usage mempty bcs $ @@ -796,7 +743,7 @@ unifyMostCommon usage t1 t2 = do linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 onDims _ _ _ _ _ = pure () - unifyWith onDims usage mempty noBreadCrumbs t1 t2 + unifyWith onDims usage mempty mempty t1 t2 t1' <- normTypeFully t1 t2' <- normTypeFully t2 newDimOnMismatch (locOf usage) t1' t2' From e098b6d7077bc7fa63fed0df4423f530e613ca05 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 25 Sep 2024 11:32:13 +0200 Subject: [PATCH 252/296] A bit more work. --- src/Language/Futhark/TypeChecker/Constraints.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 9c9e7869c9..de304a934c 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -261,8 +261,8 @@ unifySharedFields :: M.Map Name Type -> SolveM () unifySharedFields reason bcs fs1 fs2 = - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(_f, (ts1, ts2)) -> - solveEq reason bcs ts1 ts2 + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> + solveEq reason (matchingField f <> bcs) ts1 ts2 mustSupportEql :: Reason -> Type -> SolveM () mustSupportEql _reason _t = pure () @@ -355,6 +355,9 @@ subTyVar reason bcs v t = do error $ "subTyVar: Nothing v: " <> prettyNameString v -- Precondition: 'v' and 't' are both currently flexible. +-- +-- The purpose of this function is to combine the partial knowledge we +-- may have about these two type variables. unionTyVars :: Reason -> BreadCrumbs -> VName -> VName -> SolveM () unionTyVars reason bcs v t = do v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars From 7210e8f93cb89906207159c4558c2b4cb5a73537 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 5 Oct 2024 09:50:22 +0200 Subject: [PATCH 253/296] Better error for tuple mismatches. --- src/Language/Futhark/TypeChecker/Constraints.hs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index de304a934c..51fe1140b0 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -509,6 +509,15 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) map (first matchingField) $ M.toList $ M.intersectionWith (,) fs1 fs2 + | Just n1 <- length <$> areTupleFields fs1, + Just n2 <- length <$> areTupleFields fs2, + n1 /= n2 = + Left $ + "Tuples have" + <+> pretty n1 + <+> "and" + <+> pretty n2 + <+> "elements respectively." | otherwise = let missing = filter (`notElem` M.keys fs1) (M.keys fs2) From 67cdaf99ccea330f88b9f0ebf1133738842ee1f0 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 5 Oct 2024 10:55:41 +0200 Subject: [PATCH 254/296] Better reasons. --- .../Futhark/TypeChecker/Constraints.hs | 64 +++++++++++++------ src/Language/Futhark/TypeChecker/Terms2.hs | 12 +++- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 51fe1140b0..2727340a1f 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -28,19 +28,9 @@ import Data.Set qualified as S import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Error -import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..)) +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote) import Language.Futhark.TypeChecker.Types (substTyVars) --- | The reason for a type constraint. Used to generate type error --- messages. -newtype Reason = Reason - { reasonLoc :: Loc - } - deriving (Eq, Ord, Show) - -instance Located Reason where - locOf = reasonLoc - type SVar = VName -- | A shape component. `SDim` is a single dimension of unspecified @@ -68,6 +58,22 @@ type Type = TypeBase SComp NoUniqueness toType :: TypeBase Size u -> TypeBase SComp u toType = first (const SDim) +-- | The reason for a type constraint. Used to generate type error +-- messages. +data Reason + = -- | No particular reason. + Reason Loc + | -- | Arising from pattern match. + ReasonPatMatch Loc (PatBase NoInfo VName ParamType) Type + | -- | Arising from explicit ascription. + ReasonAscription Loc Type Type + deriving (Show) + +instance Located Reason where + locOf (Reason l) = l + locOf (ReasonPatMatch l _ _) = l + locOf (ReasonAscription l _ _) = l + data Ct = CtEq Reason Type Type | CtAM Reason SVar SVar (Shape SComp) @@ -542,14 +548,32 @@ solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () solveEq reason obcs orig_t1 orig_t2 = do solveCt' (obcs, (orig_t1, orig_t2)) where - cannotUnify details = do + cannotUnify notes bcs t1 t2 = do tyvars <- gets solverTyVars - typeError (locOf reason) mempty $ - "Cannot unify" - indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t1)) - "with" - indent 2 (pretty (substTyVars (substTyVar tyvars) orig_t2)) - details + case reason of + ReasonPatMatch loc pat value_t -> + typeError loc notes . stack $ + [ "Pattern", + indent 2 $ align $ pretty pat, + "cannot match value of type", + indent 2 $ align $ pretty value_t + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonAscription loc expected actual -> + typeError loc notes . stack $ + [ "Expression does not have expected type from type ascription.", + "Expected:" <+> align (pretty expected), + "Actual: " <+> align (pretty actual) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty (substTyVars (substTyVar tyvars) t1)), + "with", + indent 2 (pretty (substTyVars (substTyVar tyvars) t2)) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] solveCt' (bcs, (t1, t2)) = do tyvars <- gets solverTyVars @@ -572,7 +596,7 @@ solveEq reason obcs orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify $ pretty bcs + (False, False) -> cannotUnify mempty bcs t1 t2 (True, False) -> subTyVar reason bcs v1 t2' (False, True) -> subTyVar reason bcs v2 t1' (True, True) -> unionTyVars reason bcs v1 v2 @@ -581,7 +605,7 @@ solveEq reason obcs orig_t1 orig_t2 = do (t1', Scalar (TypeVar _ (QualName [] v2) [])) | flexible v2 -> subTyVar reason bcs v2 t1' (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify $ pretty bcs details + Left details -> cannotUnify (aNote details) bcs t1' t2' Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 8fabb67962..79013c727e 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -510,7 +510,10 @@ checkPat' (TuplePat ps loc) (Ascribed t) <*> pure loc | otherwise = do ps_tvs <- replicateM (length ps) (newTyVar loc Lifted "t") - ctEq (Reason (locOf loc)) (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) t + ctEq + (ReasonPatMatch (locOf loc) (TuplePat ps loc) (toStruct t)) + (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) + t TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc @@ -551,7 +554,10 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do case maybe_outer_t of Ascribed outer_t -> do - ctEq (Reason (locOf loc)) st' outer_t + ctEq + (ReasonAscription (locOf loc) (toStruct st') (toStruct outer_t)) + st' + outer_t PatAscription <$> checkPat' p (Ascribed st') <*> pure t' @@ -1238,7 +1244,7 @@ checkExp (Ascript e te loc) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te e_t <- expType e' st' <- asType st - ctEq (Reason (locOf e')) e_t st' + ctEq (ReasonAscription (locOf e') (toStruct st') (toStruct e_t)) e_t st' pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do e' <- checkExp e From 59ddb21561918612f8a5fa21cc900a9eb9e49d0d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 6 Oct 2024 10:24:46 +0200 Subject: [PATCH 255/296] Constructor match. --- src/Language/Futhark/TypeChecker/Constraints.hs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 2727340a1f..4835c3fd2b 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -251,7 +251,7 @@ unifySharedConstructors :: unifySharedConstructors reason bcs cs1 cs2 = forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> if length ts1 == length ts2 - then zipWithM_ (solveEq reason bcs) ts1 ts2 + then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 else typeError (locOf reason) mempty $ "Cannot unify type with constructor" @@ -532,12 +532,12 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) "Unshared fields:" <+> commasep (map pretty missing) <> "." unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = - fmap concat . forM cs' $ \(ts1, ts2) -> do + fmap concat . forM cs' $ \(c, (ts1, ts2)) -> do if length ts1 == length ts2 - then Right $ zipWith (curry (mempty,)) ts1 ts2 + then Right $ zipWith (curry (matchingConstructor c,)) ts1 ts2 else Left mempty where - cs' = M.elems $ M.intersectionWith (,) cs1 cs2 + cs' = M.toList $ M.intersectionWith (,) cs1 cs2 unify t1 t2 | Just t1' <- peelArray 1 t1, Just t2' <- peelArray 1 t2 = From 946ae65868937a25f988c053ecb28cafa14c3cc4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 6 Oct 2024 10:51:40 +0200 Subject: [PATCH 256/296] More reasons. --- .../Futhark/TypeChecker/Constraints.hs | 87 ++++++++++++------- src/Language/Futhark/TypeChecker/Terms2.hs | 8 +- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 4835c3fd2b..1075c16823 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -59,7 +59,7 @@ toType :: TypeBase Size u -> TypeBase SComp u toType = first (const SDim) -- | The reason for a type constraint. Used to generate type error --- messages. +-- messages. The expected type is always the first one. data Reason = -- | No particular reason. Reason Loc @@ -67,12 +67,14 @@ data Reason ReasonPatMatch Loc (PatBase NoInfo VName ParamType) Type | -- | Arising from explicit ascription. ReasonAscription Loc Type Type - deriving (Show) + | ReasonRetType Loc Type Type + deriving (Eq, Show) instance Located Reason where locOf (Reason l) = l locOf (ReasonPatMatch l _ _) = l locOf (ReasonAscription l _ _) = l + locOf (ReasonRetType l _ _) = l data Ct = CtEq Reason Type Type @@ -227,6 +229,12 @@ solution s = newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) +-- Try to substitute as much information as we have. +enrichType :: Type -> SolveM Type +enrichType t = do + s <- get + pure $ substTyVars (substTyVar (solverTyVars s)) t + typeError :: Loc -> Notes -> Doc () -> SolveM () typeError loc notes msg = throwError $ TypeError loc notes msg @@ -285,6 +293,50 @@ scopeViolation reason v1 ty v2 = <+> dquotes (prettyName v2) <+> "is rigidly bound in a deeper scope." +cannotUnify :: + Reason -> + Notes -> + BreadCrumbs -> + Type -> + Type -> + SolveM () +cannotUnify reason notes bcs t1 t2 = do + t1' <- enrichType t1 + t2' <- enrichType t2 + case reason of + ReasonPatMatch loc pat value_t -> + typeError loc notes . stack $ + [ "Pattern", + indent 2 $ align $ pretty pat, + "cannot match value of type", + indent 2 $ align $ pretty value_t + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonAscription loc expected actual -> + typeError loc notes . stack $ + [ "Expression does not have expected type from type ascription.", + "Expected:" <+> align (pretty expected), + "Actual: " <+> align (pretty actual) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonRetType loc expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ "Function body does not have expected type.", + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty t1'), + "with", + indent 2 (pretty t2') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () subTyVar reason bcs v t = do @@ -548,33 +600,6 @@ solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () solveEq reason obcs orig_t1 orig_t2 = do solveCt' (obcs, (orig_t1, orig_t2)) where - cannotUnify notes bcs t1 t2 = do - tyvars <- gets solverTyVars - case reason of - ReasonPatMatch loc pat value_t -> - typeError loc notes . stack $ - [ "Pattern", - indent 2 $ align $ pretty pat, - "cannot match value of type", - indent 2 $ align $ pretty value_t - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - ReasonAscription loc expected actual -> - typeError loc notes . stack $ - [ "Expression does not have expected type from type ascription.", - "Expected:" <+> align (pretty expected), - "Actual: " <+> align (pretty actual) - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - Reason loc -> - typeError loc notes . stack $ - [ "Cannot unify", - indent 2 (pretty (substTyVars (substTyVar tyvars) t1)), - "with", - indent 2 (pretty (substTyVars (substTyVar tyvars) t2)) - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - solveCt' (bcs, (t1, t2)) = do tyvars <- gets solverTyVars let flexible v = case M.lookup v tyvars of @@ -596,7 +621,7 @@ solveEq reason obcs orig_t1 orig_t2 = do | v1 == v2 -> pure () | otherwise -> case (flexible v1, flexible v2) of - (False, False) -> cannotUnify mempty bcs t1 t2 + (False, False) -> cannotUnify reason mempty bcs t1 t2 (True, False) -> subTyVar reason bcs v1 t2' (False, True) -> subTyVar reason bcs v2 t1' (True, True) -> unionTyVars reason bcs v1 v2 @@ -605,7 +630,7 @@ solveEq reason obcs orig_t1 orig_t2 = do (t1', Scalar (TypeVar _ (QualName [] v2) [])) | flexible v2 -> subTyVar reason bcs v2 t1' (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify (aNote details) bcs t1' t2' + Left details -> cannotUnify reason (aNote details) bcs t1' t2' Right eqs -> mapM_ solveCt' eqs solveCt :: Ct -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 79013c727e..40e26b7107 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -840,9 +840,9 @@ checkRetDecl body Nothing = (,Nothing) <$> expType body checkRetDecl body (Just te) = do (te', _, RetType _ st, _) <- checkTypeExp checkSizeExp' te body_t <- expType body - st' <- asType st - ctEq (Reason (locOf body)) body_t st' - pure (second (const NoUniqueness) st', Just te') + st' <- toStruct <$> asType st + ctEq (ReasonRetType (locOf body) st' body_t) st' body_t + pure (st', Just te') checkExp :: ExpBase NoInfo VName -> TermM (ExpBase Info VName) -- @@ -1372,7 +1372,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do bitraverse pure (fmap (second (onArtificial artificial)) . onTySolution params' body'') - $ solve cts' typarams tyvars' + $ solve (reverse cts') typarams tyvars' debugTraceM 3 $ unlines [ "## constraints:", From fba57820538e155f5a06a6f7b8b75dfb076c661d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 13 Oct 2024 18:55:48 +0200 Subject: [PATCH 257/296] Function application reason. --- .../Futhark/TypeChecker/Constraints.hs | 35 +++++++++-- src/Language/Futhark/TypeChecker/Terms2.hs | 59 ++++++++++++------- 2 files changed, 66 insertions(+), 28 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1075c16823..6ed3b64a84 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -68,6 +68,7 @@ data Reason | -- | Arising from explicit ascription. ReasonAscription Loc Type Type | ReasonRetType Loc Type Type + | ReasonApply Loc (Maybe (QualName VName)) Exp Type Type deriving (Eq, Show) instance Located Reason where @@ -75,6 +76,7 @@ instance Located Reason where locOf (ReasonPatMatch l _ _) = l locOf (ReasonAscription l _ _) = l locOf (ReasonRetType l _ _) = l + locOf (ReasonApply l _ _ _ _) = l data Ct = CtEq Reason Type Type @@ -304,6 +306,14 @@ cannotUnify reason notes bcs t1 t2 = do t1' <- enrichType t1 t2' <- enrichType t2 case reason of + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty t1'), + "with", + indent 2 (pretty t2') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] ReasonPatMatch loc pat value_t -> typeError loc notes . stack $ [ "Pattern", @@ -328,14 +338,27 @@ cannotUnify reason notes bcs t1 t2 = do "Actual: " <+> align (pretty actual') ] <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - Reason loc -> + ReasonApply loc f e expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual typeError loc notes . stack $ - [ "Cannot unify", - indent 2 (pretty t1'), - "with", - indent 2 (pretty t2') + [ header, + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + where + header = + case f of + Nothing -> + "Cannot apply function to" + <+> dquotes (shorten $ group $ pretty e) + <> " (invalid type)." + Just fname -> + "Cannot apply" + <+> dquotes (pretty fname) + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> " (invalid type)." -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 40e26b7107..89b5320a67 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -678,8 +678,13 @@ bindParams tps orig_ps m = bindTypeParams tps $ do incLevel $ descend [] orig_ps -checkApplyOne :: SrcLoc -> (Maybe (QualName VName), Int) -> (Shape Size, Type) -> (Shape Size, Type) -> TermM (Type, AutoMap) -checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do +checkApplyOne :: + SrcLoc -> + (Maybe (QualName VName), Int) -> + (Shape Size, Type) -> + (Maybe Exp, Shape Size, Type) -> + TermM (Type, AutoMap) +checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do (a, b) <- split ftype r <- newSVar loc "R" m <- newSVar loc "M" @@ -689,7 +694,11 @@ checkApplyOne loc fname (fframe, ftype) (argframe, argtype) = do lhs = arrayOf (toShape (SVar r)) argtype rhs = arrayOf (toShape (SVar m)) a ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) - ctEq (Reason (locOf loc)) lhs rhs + let reason = case arg of + Just arg' -> + ReasonApply (locOf loc) (fst fname) arg' lhs rhs + Nothing -> Reason (locOf loc) + ctEq reason lhs rhs debugTraceM 3 $ unlines [ "## checkApplyOne", @@ -737,14 +746,14 @@ checkApply :: SrcLoc -> Maybe (QualName VName) -> (Shape Size, Type) -> - NE.NonEmpty (Shape Size, Type) -> + NE.NonEmpty (Maybe Exp, Shape Size, Type) -> TermM (Type, NE.NonEmpty AutoMap) checkApply loc fname (fframe, ftype) args = do ((_, _, rt), argts) <- mapAccumLM onArg (0, fframe, ftype) args pure (rt, argts) where - onArg (i, f_f, f_t) (argframe, argtype) = do - (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) (argframe, argtype) + onArg (i, f_f, f_t) arg = do + (rt, am) <- checkApplyOne loc (fname, i) (f_f, f_t) arg pure ( (i + 1, autoFrame am, rt), am @@ -943,21 +952,17 @@ checkExp (Constr name es NoInfo loc) = do -- checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe - (args', argts') <- - NE.unzip - <$> forM - args - ( \(_, arg) -> do - arg' <- checkExp arg - arg_t <- expType arg' - pure (arg', (frameOf arg', arg_t)) - ) + (args', apply_args) <- + fmap NE.unzip . forM args $ \(_, arg) -> do + arg' <- checkExp arg + arg_t <- expType arg' + pure (arg', (Just arg', frameOf arg', arg_t)) fe_t <- expType fe' - (rt, ams) <- checkApply loc fname (frameOf fe', fe_t) argts' + (rt, ams) <- checkApply loc fname (frameOf fe', fe_t) apply_args rt' <- asStructType rt - pure $ - AppExp (Apply fe' (NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args') loc) $ - Info (AppRes rt' []) + let args'' = + NE.zipWith (\am arg -> (Info (Nothing, am), arg)) ams args' + pure $ AppExp (Apply fe' args'' loc) $ Info (AppRes rt' []) where fname = case fe of @@ -975,7 +980,7 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do loc (Just op) (mempty, ftype) - ((frameOf e1', e1_t) NE.:| [(frameOf e2', e2_t)]) + ((Just e1', frameOf e1', e1_t) NE.:| [(Just e2', frameOf e2', e2_t)]) rt' <- asStructType rt let (am1 NE.:| [am2]) = ams @@ -992,7 +997,12 @@ checkExp (OpSectionLeft op _ e _ _ loc) = do t2 <- newType loc Lifted "t" NoUniqueness t2' <- asStructType t2 let f1 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((f1, e_t) NE.:| [(mempty, t2)]) + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, optype) + ((Just e', f1, e_t) NE.:| [(Nothing, mempty, t2)]) rt' <- asStructType rt let (am1 NE.:| _) = ams @@ -1015,7 +1025,12 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do t1 <- newType loc Lifted "t" NoUniqueness t1' <- asStructType t1 let f2 = frameOf e' - (rt, ams) <- checkApply loc (Just op) (mempty, optype) ((mempty, t1) NE.:| [(f2, e_t)]) + (rt, ams) <- + checkApply + loc + (Just op) + (mempty, optype) + ((Nothing, mempty, t1) NE.:| [(Just e', f2, e_t)]) rt' <- asStructType rt let (_ NE.:| [am2]) = ams t2 <- asStructType e_t From 29293af45b6d9e33eddf5ef5bf53ca5c1d39ab3d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 13 Oct 2024 19:17:39 +0200 Subject: [PATCH 258/296] Branch reasons. --- src/Language/Futhark/TypeChecker/Constraints.hs | 10 ++++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 6 +++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 6ed3b64a84..c8ceb92763 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -69,6 +69,7 @@ data Reason ReasonAscription Loc Type Type | ReasonRetType Loc Type Type | ReasonApply Loc (Maybe (QualName VName)) Exp Type Type + | ReasonBranches Loc Type Type deriving (Eq, Show) instance Located Reason where @@ -77,6 +78,7 @@ instance Located Reason where locOf (ReasonAscription l _ _) = l locOf (ReasonRetType l _ _) = l locOf (ReasonApply l _ _ _ _) = l + locOf (ReasonBranches l _ _) = l data Ct = CtEq Reason Type Type @@ -359,6 +361,14 @@ cannotUnify reason notes bcs t1 t2 = do <+> "to" <+> dquotes (align $ shorten $ group $ pretty e) <> " (invalid type)." + ReasonBranches loc former latter -> do + former' <- enrichType former + latter' <- enrichType latter + typeError loc notes . stack $ + [ "Branches differ in type.", + "Former:" <+> pretty former', + "Latter:" <+> pretty latter' + ] -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 89b5320a67..41594c4282 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -810,7 +810,7 @@ checkCases mt rest_cs = (c, Just cs) -> do (c', c_t) <- checkCase mt c (cs', cs_t) <- checkCases mt cs - ctEq (Reason (locOf c)) c_t cs_t + ctEq (ReasonBranches (locOf c) c_t cs_t) c_t cs_t pure (NE.cons c' cs', c_t) -- | An unmatched pattern. Used in in the generation of @@ -1203,8 +1203,8 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = do if_t <- newType loc SizeLifted "if_t" NoUniqueness ctEq (Reason (locOf e1')) e1_t (Scalar (Prim Bool)) - ctEq (Reason (locOf loc)) e2_t if_t - ctEq (Reason (locOf loc)) e3_t if_t + ctEq (ReasonBranches (locOf loc) e2_t e3_t) e2_t if_t + ctEq (ReasonBranches (locOf loc) e2_t e3_t) e3_t if_t if_t' <- asStructType if_t pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes if_t' []) From d068577e33983e9c26fd1c817ac014627c5dbdcc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 22 Dec 2024 21:29:26 +0100 Subject: [PATCH 259/296] Error message fixes. --- .../Futhark/TypeChecker/Constraints.hs | 25 ++++++++++++------- src/Language/Futhark/TypeChecker/Terms2.hs | 3 +-- .../conditional-function0.fut | 8 +++--- tests/higher-order-functions/loops0.fut | 11 ++++---- .../match-function0.fut | 8 +++--- tests/issue514.fut | 2 +- tests/record-update1.fut | 4 +-- tests/records-error3.fut | 6 ++--- tests/records-error4.fut | 6 ++--- tests/records-error7.fut | 2 +- tests/records-error8.fut | 2 +- tests/records-error9.fut | 2 +- 12 files changed, 43 insertions(+), 36 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index c8ceb92763..5a4586ce25 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -213,6 +213,9 @@ type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -- a constraint on how it can be instantiated. type UnconTyVar = (VName, Liftedness) +typeVar :: (Monoid u) => VName -> TypeBase dim u +typeVar v = Scalar $ TypeVar mempty (qualName v) [] + solution :: SolverState -> ([UnconTyVar], Solution) solution s = ( mapMaybe unconstrained $ M.toList $ solverTyVars s, @@ -388,9 +391,11 @@ subTyVar reason bcs v t = do ) -> if t `elem` map (Scalar . Prim) v_pts then pure () - else - typeError (locOf reason) mempty $ - "Cannot unify type that must be one of" + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot instance type that must be one of" indent 2 (pretty v_pts) "with" indent 2 (pretty t) @@ -399,12 +404,14 @@ subTyVar reason bcs v t = do ) -> if all (`elem` M.keys cs2) (M.keys cs1) then unifySharedConstructors reason bcs cs1 cs2 - else - typeError (locOf reason) mempty $ - "Cannot unify type with constructors" - indent 2 (pretty (Sum cs1)) - "with type" - indent 2 (pretty (Sum cs2)) + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot match type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs1))) + "with type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs2))) ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), _ ) -> diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index a5895e708e..ce85a83576 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -698,8 +698,7 @@ checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do rhs = arrayOf (toShape (SVar m)) a ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) let reason = case arg of - Just arg' -> - ReasonApply (locOf loc) (fst fname) arg' lhs rhs + Just arg' -> ReasonApply (locOf arg) (fst fname) arg' lhs rhs Nothing -> Reason (locOf loc) ctEq reason lhs rhs debugTraceM 3 $ diff --git a/tests/higher-order-functions/conditional-function0.fut b/tests/higher-order-functions/conditional-function0.fut index 433b393c39..00f1b441bd 100644 --- a/tests/higher-order-functions/conditional-function0.fut +++ b/tests/higher-order-functions/conditional-function0.fut @@ -1,8 +1,8 @@ -- We cannot return a function from a conditional. -- == --- error: returned from branch +-- error: may not be of function type -def f (x:i32) : i32 = x+x -def g (x:i32) : i32 = x+1 +def f (x: i32) : i32 = x + x +def g (x: i32) : i32 = x + 1 -def main (b : bool) (n : i32) : i32 = (if b then f else g) n +def main (b: bool) (n: i32) : i32 = (if b then f else g) n diff --git a/tests/higher-order-functions/loops0.fut b/tests/higher-order-functions/loops0.fut index 1e8e31aa26..8001a3a7c0 100644 --- a/tests/higher-order-functions/loops0.fut +++ b/tests/higher-order-functions/loops0.fut @@ -1,9 +1,10 @@ -- The merge parameter in a loop cannot have function type. -- == --- error: used as loop variable +-- error: may not be of function -def id 'a (x : a) : a = x +def id 'a (x: a) : a = x -def main (n : i32) = - loop f = id for i < n do - \(y:i32) -> f y +def main (n: i32) = + loop f = id + for i < n do + \(y: i32) -> f y diff --git a/tests/higher-order-functions/match-function0.fut b/tests/higher-order-functions/match-function0.fut index de8ec163c8..30d43c8736 100644 --- a/tests/higher-order-functions/match-function0.fut +++ b/tests/higher-order-functions/match-function0.fut @@ -1,8 +1,8 @@ -- We cannot return a function from a pattern match. -- == --- error: returned from pattern match +-- error: may not be of function type -def f (x:i32) : i32 = x+x -def g (x:i32) : i32 = x+1 +def f (x: i32) : i32 = x + x +def g (x: i32) : i32 = x + 1 -def main (b : bool) (n : i32) : i32 = (match b case _ -> f) n +def main (b: bool) (n: i32) : i32 = (match b case _ -> f) n diff --git a/tests/issue514.fut b/tests/issue514.fut index 057d69b71a..2f70eca04f 100644 --- a/tests/issue514.fut +++ b/tests/issue514.fut @@ -1,4 +1,4 @@ -- == --- error: issue514.fut:4:13-22 +-- error: issue514.fut:4:26-36 def main = (2.0 + 3.0) / (2 + 3i32) diff --git a/tests/record-update1.fut b/tests/record-update1.fut index a0bcf524f9..d4a24afc99 100644 --- a/tests/record-update1.fut +++ b/tests/record-update1.fut @@ -1,8 +1,8 @@ -- Type-changing record update. -- == --- error: i32.*bool +-- error: bool.*i32 -def main (x: i32) (y: i32): (bool, i32) = +def main (x: i32) (y: i32) : (bool, i32) = let r0 = {x, y} let r1 = r0 with x = true in (r1.x, r1.y) diff --git a/tests/records-error3.fut b/tests/records-error3.fut index 36c0497e69..4261d1b035 100644 --- a/tests/records-error3.fut +++ b/tests/records-error3.fut @@ -1,8 +1,8 @@ -- A record value must have at least the fields of its corresponding -- type. -- == --- error: match +-- error: unshared fields -def main() = - let r:{a:i32,b:i32} = {a=0} +def main () = + let r: {a: i32, b: i32} = {a = 0} in 0 diff --git a/tests/records-error4.fut b/tests/records-error4.fut index 229ae3aef4..7463fe7637 100644 --- a/tests/records-error4.fut +++ b/tests/records-error4.fut @@ -1,8 +1,8 @@ -- A record value must not have more fields than its corresponding -- type. -- == --- error: match +-- error: Unshared fields -def main() = - let r:{a:i32} = {a=0,b=0} +def main () = + let r: {a: i32} = {a = 0, b = 0} in 0 diff --git a/tests/records-error7.fut b/tests/records-error7.fut index 8d397ebaf2..3d98d992cd 100644 --- a/tests/records-error7.fut +++ b/tests/records-error7.fut @@ -1,5 +1,5 @@ -- Specific error message on record field mismatches. -- == --- error: Unshared fields: d, c. +-- error: Unshared fields: c, d. def f (v: {a: i32, b: i32, c: i32}) : {a: i32, b: i32, d: i32} = v diff --git a/tests/records-error8.fut b/tests/records-error8.fut index d23c668f7f..c872d18d18 100644 --- a/tests/records-error8.fut +++ b/tests/records-error8.fut @@ -1,6 +1,6 @@ -- Unification of variables with incompletely known and distinct fields. -- == --- error: must be a record with fields +-- error: Cannot unify record type def sameconst '^a (_: a) (y: a) = y diff --git a/tests/records-error9.fut b/tests/records-error9.fut index 086ee3f9f0..7515695a02 100644 --- a/tests/records-error9.fut +++ b/tests/records-error9.fut @@ -1,6 +1,6 @@ -- Unification of incomplete record variable with non-record. -- == --- error: Cannot unify a record type with a non-record type +-- error: with type that must be a record def sameconst '^a (_: a) (y: a) = y From 8b0433f9a80007b0f0befae2fb5450975f711fe3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 22 Dec 2024 21:48:24 +0100 Subject: [PATCH 260/296] Fix more error messages. --- src/Language/Futhark/TypeChecker/Constraints.hs | 9 +++++++++ tests/sumtypes/sumtype46.fut | 2 +- tests/sumtypes/sumtype48.fut | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 5a4586ce25..8b8355c12a 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -373,6 +373,14 @@ cannotUnify reason notes bcs t1 t2 = do "Latter:" <+> pretty latter' ] +unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a +unsharedConstructorsMsg cs1 cs2 = + "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." + where + missing = + filter (`notElem` M.keys cs1) (M.keys cs2) + ++ filter (`notElem` M.keys cs2) (M.keys cs1) + -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () subTyVar reason bcs v t = do @@ -412,6 +420,7 @@ subTyVar reason bcs v t = do indent 2 (stack (map (("#" <>) . pretty) (M.keys cs1))) "with type with constructors" indent 2 (stack (map (("#" <>) . pretty) (M.keys cs2))) + unsharedConstructorsMsg cs1 cs2 ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), _ ) -> diff --git a/tests/sumtypes/sumtype46.fut b/tests/sumtypes/sumtype46.fut index dac4e4671f..a56af88750 100644 --- a/tests/sumtypes/sumtype46.fut +++ b/tests/sumtypes/sumtype46.fut @@ -1,5 +1,5 @@ -- == --- error: cannot match +-- error: 0 constructor arguments type t = #foo f64 diff --git a/tests/sumtypes/sumtype48.fut b/tests/sumtypes/sumtype48.fut index ad614f03cf..6c184e9239 100644 --- a/tests/sumtypes/sumtype48.fut +++ b/tests/sumtypes/sumtype48.fut @@ -3,4 +3,4 @@ type t = #foo | #bar -let f b : t = if b then #foo else #baar +def f b : t = if b then #foo else #baar From 7b57df6a58f622a31f661e66974c898a9ec71187 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 22 Dec 2024 22:10:50 +0100 Subject: [PATCH 261/296] More work on errors. --- .../Futhark/TypeChecker/Constraints.hs | 25 +++++++++++-------- tests/sumtypes/sumtype32.fut | 2 +- tests/sumtypes/sumtype47.fut | 2 +- tests/types/inference-error12.fut | 7 +++--- tests/types/inference-error3.fut | 2 +- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 8b8355c12a..1ad98997ad 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -373,14 +373,6 @@ cannotUnify reason notes bcs t1 t2 = do "Latter:" <+> pretty latter' ] -unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a -unsharedConstructorsMsg cs1 cs2 = - "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." - where - missing = - filter (`notElem` M.keys cs1) (M.keys cs2) - ++ filter (`notElem` M.keys cs2) (M.keys cs1) - -- Precondition: 'v' is currently flexible. subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () subTyVar reason bcs v t = do @@ -511,7 +503,7 @@ unionTyVars reason bcs v t = do typeError (locOf reason) mempty $ "Cannot unify type that must be one of" indent 2 (pretty v_pts) - "with type that must be record." + "with type that must be a record." ( TyVarUnsol (TyVarPrim _ v_pts), TyVarSum {} ) -> @@ -593,6 +585,14 @@ unionTyVars reason bcs v t = do alreadySolved = error $ "Type variable already solved: " <> prettyNameString v isParam = error $ "Type name is a type parameter: " <> prettyNameString v +unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a +unsharedConstructorsMsg cs1 cs2 = + "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." + where + missing = + filter (`notElem` M.keys cs1) (M.keys cs2) + ++ filter (`notElem` M.keys cs2) (M.keys cs1) + -- Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) @@ -630,13 +630,15 @@ unify (Scalar (Record fs1)) (Scalar (Record fs2)) filter (`notElem` M.keys fs1) (M.keys fs2) <> filter (`notElem` M.keys fs2) (M.keys fs1) in Left $ - "Unshared fields:" <+> commasep (map pretty missing) <> "." + "unshared fields:" <+> commasep (map pretty missing) <> "." unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) | M.keys cs1 == M.keys cs2 = fmap concat . forM cs' $ \(c, (ts1, ts2)) -> do if length ts1 == length ts2 then Right $ zipWith (curry (matchingConstructor c,)) ts1 ts2 else Left mempty + | otherwise = + Left $ unsharedConstructorsMsg cs1 cs2 where cs' = M.toList $ M.intersectionWith (,) cs1 cs2 unify t1 t2 @@ -773,6 +775,9 @@ solveTyVar (tv, (lvl, TyVarFree loc l)) = do solveTyVar (tv, (_, TyVarPrim loc pts)) = do tv_t <- lookupTyVar tv case tv_t of + Right (Scalar (Prim ty)) + | [ty] == pts -> + setInfo tv $ TyVarSol $ Scalar $ Prim ty Right ty | ty `elem` map (Scalar . Prim) pts -> pure () | otherwise -> diff --git a/tests/sumtypes/sumtype32.fut b/tests/sumtypes/sumtype32.fut index ca5f22c91f..b5255b580c 100644 --- a/tests/sumtypes/sumtype32.fut +++ b/tests/sumtypes/sumtype32.fut @@ -1,5 +1,5 @@ -- Specific error message on constructor mismatches. -- == --- error: Unshared constructors: #d, #c. +-- error: Unshared constructors: #c, #d. def f (v: #a i32 | #b i32 | #c i32) : #a i32 | #b i32 | #d i32 = v diff --git a/tests/sumtypes/sumtype47.fut b/tests/sumtypes/sumtype47.fut index b3fb9c7ce5..09a51f3faa 100644 --- a/tests/sumtypes/sumtype47.fut +++ b/tests/sumtypes/sumtype47.fut @@ -1,5 +1,5 @@ -- == --- error: cannot match +-- error: 2 constructor arguments type t = #foo f64 diff --git a/tests/types/inference-error12.fut b/tests/types/inference-error12.fut index 5e8dc8e3d8..01e499f875 100644 --- a/tests/types/inference-error12.fut +++ b/tests/types/inference-error12.fut @@ -1,6 +1,7 @@ -- A record turns out to be missing a field. -- == --- error: expected type +-- error: unify record type -def f r = let y = r.l2 - in (r: {l1: i32}) +def f r = + let y = r.l2 + in (r : {l1: i32}) diff --git a/tests/types/inference-error3.fut b/tests/types/inference-error3.fut index 7a2da29bd6..6eda032d42 100644 --- a/tests/types/inference-error3.fut +++ b/tests/types/inference-error3.fut @@ -1,5 +1,5 @@ -- If something is applied, it cannot later be put in an array. -- == --- error: -> b +-- error: -> def f x = (x 2, [x]) From 845e340f69a603329bde90baeaa885d59ee52a5c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 22 Dec 2024 22:27:58 +0100 Subject: [PATCH 262/296] Fix warnings about overloading. --- src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index ce85a83576..975b7ae9e1 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -1337,7 +1337,7 @@ generaliseAndDefaults :: generaliseAndDefaults unconstrained solution t = do let (generalised, unconstrained') = generalise t unconstrained solution - solution' <- doDefaults (map typeParamName generalised) solution + solution' <- doDefaults (S.toList $ typeVars t) solution pure ( generalised, -- See #1552 for why we resolve unconstrained and From 6d541a0cd14cbda6905774322c3f96db0015ab15 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 22 Dec 2024 22:43:35 +0100 Subject: [PATCH 263/296] Fix more things. --- tests/issue1783.fut | 7 ++++--- tests/records-error4.fut | 2 +- tests/records-error7.fut | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/issue1783.fut b/tests/issue1783.fut index b585962c07..0b8e197910 100644 --- a/tests/issue1783.fut +++ b/tests/issue1783.fut @@ -1,8 +1,9 @@ -- == --- error: cannot match +-- error: constructor arguments -type surface = #asphere {curvature: f64} - | #sphere {curvature: f64} +type surface = + #asphere {curvature: f64} + | #sphere {curvature: f64} entry sag (surf: surface) : f64 = match surf diff --git a/tests/records-error4.fut b/tests/records-error4.fut index 7463fe7637..c21797bbc1 100644 --- a/tests/records-error4.fut +++ b/tests/records-error4.fut @@ -1,7 +1,7 @@ -- A record value must not have more fields than its corresponding -- type. -- == --- error: Unshared fields +-- error: unshared fields def main () = let r: {a: i32} = {a = 0, b = 0} diff --git a/tests/records-error7.fut b/tests/records-error7.fut index 3d98d992cd..122d539878 100644 --- a/tests/records-error7.fut +++ b/tests/records-error7.fut @@ -1,5 +1,5 @@ -- Specific error message on record field mismatches. -- == --- error: Unshared fields: c, d. +-- error: unshared fields: c, d. def f (v: {a: i32, b: i32, c: i32}) : {a: i32, b: i32, d: i32} = v From 4f88fef36587c6b02edd7865d5b9bf0f0752fa63 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 24 Dec 2024 00:09:28 +0100 Subject: [PATCH 264/296] Eliminate TyVarEql. --- .../Futhark/TypeChecker/Constraints.hs | 46 ------------------- src/Language/Futhark/TypeChecker/Rank.hs | 3 -- src/Language/Futhark/TypeChecker/Terms.hs | 21 +++++++++ src/Language/Futhark/TypeChecker/Terms2.hs | 2 +- tests/types/inference-error9.fut | 2 +- 5 files changed, 23 insertions(+), 51 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1ad98997ad..b6d186b195 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -110,8 +110,6 @@ data TyVarInfo TyVarRecord Loc (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum Loc (M.Map Name [Type]) - | -- | Must be a type that supports equality. - TyVarEql Loc deriving (Show, Eq) instance Pretty TyVarInfo where @@ -119,14 +117,12 @@ instance Pretty TyVarInfo where pretty (TyVarPrim _ pts) = "∈" <+> pretty pts pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs - pretty (TyVarEql _) = "equality" instance Located TyVarInfo where locOf (TyVarFree loc _) = loc locOf (TyVarPrim loc _) = loc locOf (TyVarRecord loc _) = loc locOf (TyVarSum loc _) = loc - locOf (TyVarEql loc) = loc type TyVar = VName @@ -285,9 +281,6 @@ unifySharedFields reason bcs fs1 fs2 = forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> solveEq reason (matchingField f <> bcs) ts1 ts2 -mustSupportEql :: Reason -> Type -> SolveM () -mustSupportEql _reason _t = pure () - scopeViolation :: Reason -> VName -> Type -> VName -> SolveM () scopeViolation reason v1 ty v2 = typeError (locOf reason) mempty $ @@ -440,8 +433,6 @@ subTyVar reason bcs v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty t) - (Just (Right (TyVarUnsol (TyVarEql _))), _) -> - mustSupportEql reason t -- -- Internal error cases (Just (Right TyVarSol {}), _) -> @@ -481,10 +472,6 @@ unionTyVars reason bcs v t = do setInfo t (TyVarUnsol info) -- -- TyVarPrim cases - ( TyVarUnsol info@TyVarPrim {}, - TyVarEql {} - ) -> - setInfo t (TyVarUnsol info) ( TyVarUnsol (TyVarPrim _ v_pts), TyVarPrim t_loc t_pts ) -> @@ -533,10 +520,6 @@ unionTyVars reason bcs v t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Scalar (Record fs))) - ( TyVarUnsol (TyVarSum _ cs1), - TyVarEql _ - ) -> - mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases ( TyVarUnsol (TyVarRecord _ fs1), @@ -559,20 +542,6 @@ unionTyVars reason bcs v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty (Scalar (Sum cs))) - ( TyVarUnsol (TyVarRecord _ fs1), - TyVarEql _ - ) -> - mapM_ (mustSupportEql reason) fs1 - -- - -- TyVarEql cases - (TyVarUnsol (TyVarEql _), TyVarPrim {}) -> - pure () - (TyVarUnsol (TyVarEql _), TyVarEql {}) -> - pure () - (TyVarUnsol (TyVarEql _), TyVarRecord _ fs) -> - mustSupportEql reason $ Scalar $ Record fs - (TyVarUnsol (TyVarEql _), TyVarSum _ cs) -> - mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases (TyVarSol {}, _) -> @@ -750,21 +719,6 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) Right _ -> pure () -solveTyVar (tv, (_, TyVarEql loc)) = do - tv_t <- lookupTyVar tv - case tv_t of - Left TyVarEql {} -> - typeError loc mempty $ - "Type is ambiguous (must be equality type)" - "Add a type annotation to disambiguate the type." - Left _ -> pure () - Right ty - | orderZero ty -> pure () - | otherwise -> - typeError loc mempty $ - "Type" - indent 2 (align (pretty ty)) - "does not support equality (may contain function)." solveTyVar (tv, (lvl, TyVarFree loc l)) = do tv_t <- lookupTyVar tv case tv_t of diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 24254d7392..deb578facb 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -167,8 +167,6 @@ addTyVarInfo tv (_, TyVarRecord {}) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarEql {}) = - addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: [Ct] -> TyVars -> LinearProg mkLinearProg cs tyVars = @@ -454,7 +452,6 @@ instance SubstRanks TyVarInfo where TyVarRecord loc <$> traverse substRanks fs substRanks (TyVarSum loc cs) = TyVarSum loc <$> (traverse . traverse) substRanks cs - substRanks tv@TyVarEql {} = pure tv instance SubstRanks (Int, TyVarInfo) where substRanks (lvl, tv) = (lvl,) <$> substRanks tv diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 48f82ad5ea..b6009887ed 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1143,6 +1143,14 @@ mustBeIrrefutable p = do "Refutable pattern not allowed here.\nUnmatched cases:" indent 2 (stack (map pretty ps')) +supportsEquality :: TypeBase dim u -> Bool +supportsEquality (Array _ _ t) = supportsEquality $ Scalar t +supportsEquality (Scalar Prim {}) = True +supportsEquality (Scalar TypeVar {}) = False +supportsEquality (Scalar (Record fs)) = all supportsEquality fs +supportsEquality (Scalar (Sum fs)) = all (all supportsEquality) fs +supportsEquality (Scalar Arrow {}) = False + -- | Traverse the expression, emitting warnings and errors for various -- problems: -- @@ -1163,6 +1171,12 @@ localChecks = void . check indent 2 (stack (map pretty ps')) check e@(AppExp (LetPat _ p _ _ _) _) = mustBeIrrefutable p *> recurse e + check e@(AppExp (BinOp (v, loc) _ (x, _) _ _) _) + | qualLeaf v == intrinsicVar "==" = + checkEquality loc (typeOf x) *> recurse e + check e@(Var v (Info t) loc) + | qualLeaf v == intrinsicVar "==" = + checkEquality loc t *> recurse e check e@(Lambda ps _ _ _ _) = mapM_ (mustBeIrrefutable . fmap toStruct) ps *> recurse e check e@(AppExp (LetFun _ (_, ps, _, _, _) _ _) _) = @@ -1188,6 +1202,13 @@ localChecks = void . check check e = recurse e recurse = astMap identityMapper {mapOnExp = check} + checkEquality loc t = + unless (supportsEquality t) $ + typeError loc mempty $ + "Comparing equality of values of type" + indent 2 (pretty t) + "which does not support equality." + bitWidth ty = 8 * intByteSize ty :: Int inBoundsI x (Signed t) = x >= -2 ^ (bitWidth t - 1) && x < 2 ^ (bitWidth t - 1) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 975b7ae9e1..679dc6317b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -438,7 +438,7 @@ lookupVar loc qn@(QualName qs name) = do -- TODO - qualify type names, like in the old type checker. pure t' Just EqualityF -> do - argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarEql (locOf loc)) + argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarFree (locOf loc) Unlifted) pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts diff --git a/tests/types/inference-error9.fut b/tests/types/inference-error9.fut index 0c6886c881..6cead89cec 100644 --- a/tests/types/inference-error9.fut +++ b/tests/types/inference-error9.fut @@ -2,4 +2,4 @@ -- == -- error: equality -def main 't (x: t) = x == x +def f 't (x: t) = x == x From 0e3867d45e0806bf91ce873736dd7172657e3b74 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 24 Dec 2024 11:26:37 +0100 Subject: [PATCH 265/296] Better error message. --- .../Futhark/TypeChecker/Constraints.hs | 21 ++++++++++++++++--- src/Language/Futhark/TypeChecker/Terms2.hs | 7 +++++-- tests/funcall-error1.fut | 6 +++--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index b6d186b195..da5f631263 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -68,7 +68,11 @@ data Reason | -- | Arising from explicit ascription. ReasonAscription Loc Type Type | ReasonRetType Loc Type Type - | ReasonApply Loc (Maybe (QualName VName)) Exp Type Type + | ReasonApply Loc (Maybe (QualName VName), Int) Exp Type Type + | -- | Used when unifying a type with a function type in a function + -- application. If this unification fails, it means the supposed + -- function was not a function after all. + ReasonApplySplit Loc (QualName VName, Int) Exp | ReasonBranches Loc Type Type deriving (Eq, Show) @@ -78,6 +82,7 @@ instance Located Reason where locOf (ReasonAscription l _ _) = l locOf (ReasonRetType l _ _) = l locOf (ReasonApply l _ _ _ _) = l + locOf (ReasonApplySplit l _ _) = l locOf (ReasonBranches l _ _) = l data Ct @@ -347,16 +352,26 @@ cannotUnify reason notes bcs t1 t2 = do where header = case f of - Nothing -> + (Nothing, _) -> "Cannot apply function to" <+> dquotes (shorten $ group $ pretty e) <> " (invalid type)." - Just fname -> + (Just fname, _) -> "Cannot apply" <+> dquotes (pretty fname) <+> "to" <+> dquotes (align $ shorten $ group $ pretty e) <> " (invalid type)." + ReasonApplySplit loc (fname, i) e -> + typeError loc notes $ + stack + [ "Cannot apply" + <+> dquotes (pretty fname) + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> ".", + "Function accepts only" <+> pretty i <+> "arguments." + ] ReasonBranches loc former latter -> do former' <- enrichType former latter' <- enrichType latter diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 679dc6317b..b2fae2a1cb 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -698,7 +698,7 @@ checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do rhs = arrayOf (toShape (SVar m)) a ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) let reason = case arg of - Just arg' -> ReasonApply (locOf arg) (fst fname) arg' lhs rhs + Just arg' -> ReasonApply (locOf arg) fname arg' lhs rhs Nothing -> Reason (locOf loc) ctEq reason lhs rhs debugTraceM 3 $ @@ -741,7 +741,10 @@ checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do split ftype' = do a <- newType loc Lifted "arg" NoUniqueness b <- newType loc Lifted "res" Nonunique - ctEq (Reason (locOf loc)) ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b + let reason = case (fname, arg) of + ((Just fname', i), Just arg') -> ReasonApplySplit (locOf loc) (fname', i) arg' + _ -> Reason (locOf loc) + ctEq reason ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) checkApply :: diff --git a/tests/funcall-error1.fut b/tests/funcall-error1.fut index 8f5949246c..c9791213af 100644 --- a/tests/funcall-error1.fut +++ b/tests/funcall-error1.fut @@ -1,7 +1,7 @@ -- Test that functions accept only the right number of arguments. -- == --- error: Cannot apply "f" +-- error: 2 arguments -def f(x: i32) (y: f64): f64 = f64.i32 (x) + y +def f (x: i32) (y: f64) : f64 = f64.i32 (x) + y -def main: f64 = f 2 2.0 3 +def main : f64 = f 2 2.0 3 From 235e6dadb0922abeccccd676fd77e9f658460d76 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 24 Dec 2024 12:01:39 +0100 Subject: [PATCH 266/296] Further improve this. --- .../Futhark/TypeChecker/Constraints.hs | 20 +++++++++++++++---- src/Language/Futhark/TypeChecker/Terms2.hs | 6 +++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index da5f631263..56df66e7c9 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -72,7 +72,7 @@ data Reason | -- | Used when unifying a type with a function type in a function -- application. If this unification fails, it means the supposed -- function was not a function after all. - ReasonApplySplit Loc (QualName VName, Int) Exp + ReasonApplySplit Loc (Maybe (QualName VName), Int) Exp Type | ReasonBranches Loc Type Type deriving (Eq, Show) @@ -82,7 +82,7 @@ instance Located Reason where locOf (ReasonAscription l _ _) = l locOf (ReasonRetType l _ _) = l locOf (ReasonApply l _ _ _ _) = l - locOf (ReasonApplySplit l _ _) = l + locOf (ReasonApplySplit l _ _ _) = l locOf (ReasonBranches l _ _) = l data Ct @@ -362,16 +362,28 @@ cannotUnify reason notes bcs t1 t2 = do <+> "to" <+> dquotes (align $ shorten $ group $ pretty e) <> " (invalid type)." - ReasonApplySplit loc (fname, i) e -> + ReasonApplySplit loc (fname, 0) _ ftype -> typeError loc notes $ stack [ "Cannot apply" - <+> dquotes (pretty fname) + <+> fname' + <+> "as function, as it has non-function type:" + indent 2 (align $ pretty ftype) + ] + where + fname' = maybe "expression" (dquotes . pretty) fname + ReasonApplySplit loc (fname, i) e _ -> + typeError loc notes $ + stack + [ "Cannot apply" + <+> fname' <+> "to" <+> dquotes (align $ shorten $ group $ pretty e) <> ".", "Function accepts only" <+> pretty i <+> "arguments." ] + where + fname' = maybe "expression" (dquotes . pretty) fname ReasonBranches loc former latter -> do former' <- enrichType former latter' <- enrichType latter diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index b2fae2a1cb..77486e00e8 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -741,9 +741,9 @@ checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do split ftype' = do a <- newType loc Lifted "arg" NoUniqueness b <- newType loc Lifted "res" Nonunique - let reason = case (fname, arg) of - ((Just fname', i), Just arg') -> ReasonApplySplit (locOf loc) (fname', i) arg' - _ -> Reason (locOf loc) + let reason = case arg of + Just arg' -> ReasonApplySplit (locOf loc) fname arg' ftype' + Nothing -> Reason $ locOf loc ctEq reason ftype' $ Scalar $ Arrow NoUniqueness Unnamed Observe a $ RetType [] b pure (a, b `setUniqueness` NoUniqueness) From 9d22ce6bba7702f1970d445316526b2240fda8ee Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 24 Dec 2024 12:21:52 +0100 Subject: [PATCH 267/296] Different error. --- tests/types/inference-error7.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/types/inference-error7.fut b/tests/types/inference-error7.fut index 0c424be153..c3a5a0fcce 100644 --- a/tests/types/inference-error7.fut +++ b/tests/types/inference-error7.fut @@ -1,5 +1,5 @@ -- Ambiguous equality type. -- == --- error: ambiguous +-- error: does not support equality def add x y = x == y From 6503e99901b62acbcbf1e578b7e3c0ea7365b766 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 3 Jan 2025 19:06:14 +0100 Subject: [PATCH 268/296] Detect scope violations. --- src/Language/Futhark/TypeChecker/Constraints.hs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 56df66e7c9..53b37cc77d 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -695,6 +695,8 @@ scopeCheck reason v v_lvl ty = do case ty_v_info of Just (Right (TyVarParam ty_v_lvl _ _)) | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + Just (Right (TyVarSol ty')) -> + mapM_ check $ typeVars ty' _ -> pure () -- If a type variable has a liftedness constraint, we propagate that From 024e3ccc071212a84bc12415eaac531ffa6b3454 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 3 Jan 2025 20:05:56 +0100 Subject: [PATCH 269/296] Careful with patternType. --- src/Language/Futhark/TypeChecker/Terms2.hs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 77486e00e8..c24e303048 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -595,13 +595,13 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do ps' <- forM ps $ \p -> do p_t <- newType (srclocOf p) Lifted "t" Observe checkPat' p $ Ascribed p_t - t' <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' + t' <- newTypeWithConstr loc "t" Observe n =<< mapM (asType . patternType) ps' ctEq (Reason (locOf loc)) t' t t'' <- asStructType t' pure $ PatConstr n (Info $ toParam Observe t'') ps' loc checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps - t <- newTypeWithConstr loc "t" Observe n $ map (toType . patternType) ps' + t <- newTypeWithConstr loc "t" Observe n =<< mapM (asType . patternType) ps' t' <- asStructType t pure $ PatConstr n (Info $ toParam Observe t') ps' loc @@ -807,13 +807,13 @@ checkCases :: Type -> NE.NonEmpty (CaseBase NoInfo VName) -> TermM (NE.NonEmpty (CaseBase Info VName), Type) -checkCases mt rest_cs = - case NE.uncons rest_cs of - (c, Nothing) -> do - (c', t) <- checkCase mt c - pure (NE.singleton c', t) - (c, Just cs) -> do - (c', c_t) <- checkCase mt c +checkCases mt rest_cs = do + let (c, rest_cs') = NE.uncons rest_cs + (c', c_t) <- checkCase mt c + case rest_cs' of + Nothing -> + pure (NE.singleton c', c_t) + Just cs -> do (cs', cs_t) <- checkCases mt cs ctEq (ReasonBranches (locOf c) c_t cs_t) c_t cs_t pure (NE.cons c' cs', c_t) From 4a369e272310fb5e500f83278a42f09ebc24d073 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 3 Jan 2025 20:14:48 +0100 Subject: [PATCH 270/296] New error message. --- tests/shapes/funshape2.fut | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/shapes/funshape2.fut b/tests/shapes/funshape2.fut index 4fb6523f0f..dc17af92aa 100644 --- a/tests/shapes/funshape2.fut +++ b/tests/shapes/funshape2.fut @@ -1,4 +1,4 @@ -- == --- error: scope violation +-- error: causality -def main xs = (\f' -> f' (filter (>0) xs)) (\_ -> 0) +def main xs = (\f' -> f' (filter (> 0) xs)) (\_ -> 0) From a1e4607f8e9d8c84008804276b8c8191e46a8993 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 11 Jan 2025 22:18:18 +0100 Subject: [PATCH 271/296] Move this out. --- src/Language/Futhark/TypeChecker/Types.hs | 48 ++++++++++++----------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 096551cd4d..9e0cf4057d 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -439,6 +439,32 @@ applyType ps t args = substTypesAny (`M.lookup` substs) t mkSubst p a = error $ "applyType mkSubst: cannot substitute " ++ prettyString a ++ " for " ++ prettyString p +-- In case we are substituting the same RetType in multiple +-- places, we must ensure each instance is given distinct +-- dimensions. E.g. substituting 'a ↦ ?[n].[n]bool' into '(a,a)' +-- should give '?[n][m].([n]bool,[m]bool)'. +-- +-- XXX: the size names we invent here not globally unique. This +-- is _probably_ not a problem, since substituting types with +-- outermost non-null existential sizes is done only when type +-- checking modules and monomorphising. +freshDims :: + (Monoid as) => + RetTypeBase Size as -> + State [VName] (RetTypeBase Size as) +freshDims (RetType [] t) = pure $ RetType [] t +freshDims (RetType ext t) = do + seen_ext <- get + if not $ any (`elem` seen_ext) ext + then pure $ RetType ext t + else do + let start = maximum $ map baseTag seen_ext + ext' = zipWith VName (map baseName ext) [start + 1 ..] + mkSubst = ExpSubst . flip sizeFromName mempty . qualName + extsubsts = M.fromList $ zip ext $ map mkSubst ext' + RetType [] t' = substTypesRet (`M.lookup` extsubsts) t + pure $ RetType ext' t' + substTypesRet :: (Monoid u) => (VName -> Maybe (Subst (RetTypeBase Size u))) -> @@ -447,28 +473,6 @@ substTypesRet :: substTypesRet lookupSubst ot = uncurry (flip RetType) $ runState (onType ot) [] where - -- In case we are substituting the same RetType in multiple - -- places, we must ensure each instance is given distinct - -- dimensions. E.g. substituting 'a ↦ ?[n].[n]bool' into '(a,a)' - -- should give '?[n][m].([n]bool,[m]bool)'. - -- - -- XXX: the size names we invent here not globally unique. This - -- is _probably_ not a problem, since substituting types with - -- outermost non-null existential sizes is done only when type - -- checking modules and monomorphising. - freshDims (RetType [] t) = pure $ RetType [] t - freshDims (RetType ext t) = do - seen_ext <- get - if not $ any (`elem` seen_ext) ext - then pure $ RetType ext t - else do - let start = maximum $ map baseTag seen_ext - ext' = zipWith VName (map baseName ext) [start + 1 ..] - mkSubst = ExpSubst . flip sizeFromName mempty . qualName - extsubsts = M.fromList $ zip ext $ map mkSubst ext' - RetType [] t' = substTypesRet (`M.lookup` extsubsts) t - pure $ RetType ext' t' - onType :: forall as. (Monoid as) => From d04c3667f0fa499407374bc5ea5816ea381d53d7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 28 Jan 2025 12:02:29 +0100 Subject: [PATCH 272/296] Separate type constraints and AUTOMAP constraints. --- .../Futhark/TypeChecker/Constraints.hs | 9 ++-- src/Language/Futhark/TypeChecker/Rank.hs | 46 +++++++++++-------- src/Language/Futhark/TypeChecker/Terms2.hs | 16 +++++-- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 53b37cc77d..363735b0b8 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,3 +1,5 @@ +-- | Constraint solver for solving type equations produced +-- post-AUTOMAP. module Language.Futhark.TypeChecker.Constraints ( Reason (..), SVar, @@ -85,21 +87,17 @@ instance Located Reason where locOf (ReasonApplySplit l _ _ _) = l locOf (ReasonBranches l _ _) = l -data Ct - = CtEq Reason Type Type - | CtAM Reason SVar SVar (Shape SComp) +data Ct = CtEq Reason Type Type deriving (Show) ctReason :: Ct -> Reason ctReason (CtEq r _ _) = r -ctReason (CtAM r _ _ _) = r instance Located Ct where locOf = locOf . ctReason instance Pretty Ct where pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2 - pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" type Constraints = [Ct] @@ -684,7 +682,6 @@ solveCt :: Ct -> SolveM () solveCt ct = case ct of CtEq reason t1 t2 -> solveEq reason mempty t1 t2 - CtAM {} -> pure () -- Good vibes only. scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () scopeCheck reason v v_lvl ty = do diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index deb578facb..da20eee308 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -1,6 +1,7 @@ module Language.Futhark.TypeChecker.Rank ( rankAnalysis, rankAnalysis1, + CtAM (..), ) where @@ -25,6 +26,14 @@ import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe +data CtAM = CtAM Reason SVar SVar (Shape SComp) + +instance Located CtAM where + locOf (CtAM r _ _ _) = locOf r + +instance Pretty CtAM where + pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" + type LSum = LP.LSum VName Int type Constraint = LP.Constraint VName Int @@ -84,7 +93,6 @@ distribAndSplitArrows (CtEq r t1 t2) = t1r' = t1r `setUniqueness` NoUniqueness t2r' = t2r `setUniqueness` NoUniqueness splitArrows c = [c] -distribAndSplitArrows ct = [ct] distribAndSplitCnstrs :: Ct -> [Ct] distribAndSplitCnstrs ct@(CtEq r t1 t2) = @@ -103,7 +111,6 @@ distribAndSplitCnstrs ct@(CtEq r t1 t2) = splitCnstrs (CtEq reason (Scalar (Sum cs1)) (Scalar (Sum cs2))) = concat $ concat $ (zipWith . zipWith) (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems cs1) (M.elems cs2) splitCnstrs _ = [] -distribAndSplitCnstrs ct = [ct] data RankState = RankState { rankBinVars :: Map VName VName, @@ -148,7 +155,9 @@ addObj sv = addCt :: Ct -> RankM () addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2 -addCt (CtAM _ r m f) = do + +addCtAM :: CtAM -> RankM () +addCtAM (CtAM _ r m f) = do b_r <- binVar r b_m <- binVar m b_max <- VName "c_max" <$> incCounter @@ -168,8 +177,8 @@ addTyVarInfo tv (_, TyVarRecord {}) = addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 -mkLinearProg :: [Ct] -> TyVars -> LinearProg -mkLinearProg cs tyVars = +mkLinearProg :: [Ct] -> [CtAM] -> TyVars -> LinearProg +mkLinearProg cs cs_am tyVars = LP.LinearProg { optType = Minimize, objective = rankObj finalState, @@ -187,6 +196,7 @@ mkLinearProg cs tyVars = } buildLP = do mapM_ addCt cs + mapM_ addCtAM cs_am mapM_ (uncurry addTyVarInfo) $ M.toList tyVars finalState = flip execState initState $ runRankM buildLP @@ -249,7 +259,7 @@ solveRankILP loc prog = do rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> - [Ct] -> + ([Ct], [CtAM]) -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> @@ -261,8 +271,8 @@ rankAnalysis1 :: Exp, Maybe (TypeExp Exp VName) ) -rankAnalysis1 loc cs tyVars artificial params body retdecl = do - solutions <- rankAnalysis loc cs tyVars artificial params body retdecl +rankAnalysis1 loc (cs, cs_am) tyVars artificial params body retdecl = do + solutions <- rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl case solutions of [sol] -> pure sol sols -> do @@ -277,7 +287,7 @@ rankAnalysis1 loc cs tyVars artificial params body retdecl = do rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> - [Ct] -> + ([Ct], [CtAM]) -> TyVars -> M.Map TyVar Type -> [Pat ParamType] -> @@ -290,9 +300,9 @@ rankAnalysis :: Maybe (TypeExp Exp VName) ) ] -rankAnalysis _ [] tyVars artificial params body retdecl = +rankAnalysis _ ([], []) tyVars artificial params body retdecl = pure [(([], artificial, tyVars), params, body, retdecl)] -rankAnalysis loc cs tyVars artificial params body retdecl = do +rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl = do debugTraceM 3 $ unlines [ "##rankAnalysis", @@ -301,8 +311,8 @@ rankAnalysis loc cs tyVars artificial params body retdecl = do "cs':", unlines $ map prettyString cs' ] - rank_maps <- solveRankILP loc (mkLinearProg cs' tyVars) - cts_tyvars' <- mapM (substRankInfo cs artificial tyVars) rank_maps + rank_maps <- solveRankILP loc (mkLinearProg cs' cs_am tyVars) + cts_tyvars' <- mapM (substRankInfo (cs, cs_am) artificial tyVars) rank_maps let bodys = map (`updAM` body) rank_maps params' = map ((`map` params) . updAMPat) rank_maps retdecls = map ((<$> retdecl) . updAMTypeExp) rank_maps @@ -316,19 +326,16 @@ type RankMap = M.Map VName Int substRankInfo :: (MonadTypeChecker m) => - [Ct] -> + ([Ct], [CtAM]) -> M.Map VName Type -> TyVars -> RankMap -> m ([Ct], M.Map VName Type, TyVars) -substRankInfo cs artificial tyVars rankmap = do +substRankInfo (cs, _cs_am) artificial tyVars rankmap = do ((cs', artificial', tyVars'), new_cs, new_tyVars) <- runSubstT tyVars rankmap $ - (,,) <$> substRanks (filter (not . isCtAM) cs) <*> traverse substRanks artificial <*> traverse substRanks tyVars + (,,) <$> substRanks cs <*> traverse substRanks artificial <*> traverse substRanks tyVars pure (cs' <> new_cs, artificial', new_tyVars <> tyVars') - where - isCtAM (CtAM {}) = True - isCtAM _ = False runSubstT :: (MonadTypeChecker m) => TyVars -> RankMap -> SubstT m a -> m (a, [Ct], TyVars) runSubstT tyVars rankmap (SubstT m) = do @@ -443,7 +450,6 @@ instance SubstRanks (TypeBase SComp u) where instance SubstRanks Ct where substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2 - substRanks _ = error "" instance SubstRanks TyVarInfo where substRanks tv@TyVarFree {} = pure tv diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index c24e303048..cb1cea2246 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -112,6 +112,7 @@ data TermEnv = TermEnv -- generating unique names, as these will be user-visible. data TermState = TermState { termConstraints :: Constraints, + termAM :: [CtAM], termTyVars :: TyVars, termTyParams :: TyParams, termCounter :: !Int, @@ -193,6 +194,7 @@ runTermM (TermM m) = do initial_state = TermState { termConstraints = mempty, + termAM = mempty, termTyVars = mempty, termTyParams = mempty, termWarnings = mempty, @@ -311,7 +313,10 @@ ctEq reason t1 t2 = t2' = t2 `setUniqueness` NoUniqueness ctAM :: Reason -> SVar -> SVar -> Shape SComp -> TermM () -ctAM reason r m f = addCt $ CtAM reason r m f +ctAM reason r m f = + modify $ \s -> s {termAM = ct : termAM s} + where + ct = CtAM reason r m f localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -1370,6 +1375,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do pure (params', body', retdecl') cts <- gets termConstraints + cts_am <- gets termAM tyvars <- gets termTyVars typarams <- gets termTyParams artificial <- gets termArtificial @@ -1389,7 +1395,7 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do ] onRankSolution typarams - =<< rankAnalysis1 loc cts tyvars artificial params' body' retdecl' + =<< rankAnalysis1 loc (cts, cts_am) tyvars artificial params' body' retdecl' where onRankSolution typarams ((cts', artificial, tyvars'), params', body'', retdecl') = do solution <- @@ -1430,11 +1436,12 @@ checkSingleExp :: checkSingleExp e = runTermM $ do e' <- checkExp e cts <- gets termConstraints + cts_am <- gets termAM tyvars <- gets termTyVars typarams <- gets termTyParams artificial <- gets termArtificial ((cts', _artificial', tyvars'), _, e'', _) <- - rankAnalysis1 (srclocOf e') cts tyvars artificial [] e' Nothing + rankAnalysis1 (srclocOf e') (cts, cts_am) tyvars artificial [] e' Nothing case solve cts' typarams tyvars' of Left err -> pure (Left err, e'') Right (unconstrained, solution) -> do @@ -1450,12 +1457,13 @@ checkSizeExp :: checkSizeExp e = runTermM $ do e' <- checkSizeExp' e cts <- gets termConstraints + cts_am <- gets termAM tyvars <- gets termTyVars typarams <- gets termTyParams artificial <- gets termArtificial (cts_tyvars', _, es', _) <- - L.unzip4 <$> rankAnalysis (srclocOf e) cts tyvars artificial [] e' Nothing + L.unzip4 <$> rankAnalysis (srclocOf e) (cts, cts_am) tyvars artificial [] e' Nothing solutions <- forM cts_tyvars' $ \(cts', _artificial', tyvars') -> From c08ef7d13ab1d41040c1ef7b3ca9ff597a89b8c1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 28 Jan 2025 16:12:13 +0100 Subject: [PATCH 273/296] Improve comments. --- src/Language/Futhark/TypeChecker/Terms2.hs | 43 ++++++++-------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index cb1cea2246..eee234a221 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -2,16 +2,17 @@ -- -- The strategy is to split type checking into two (main) passes: -- --- 1) A size-agnostic pass that generates constraints (type Ct) which --- are then solved offline to find a solution. This produces an AST --- where most of the type annotations are just references to type --- variables. Further, all the size-specific annotations (e.g. --- existential sizes) just contain dummy values, such as empty lists. --- The constraints use a type representation where all dimensions are --- the same. However, we do try to take to store the sizes resulting --- from explicit type ascriptions - these cannot refer to inferred --- existentials, so it is safe to resolve them here. We don't do --- anything with this information, however. +-- 1) A size-agnostic pass that generates type constraints (type Ct) +-- and AUTOMAP constraints (type CtAM) which are later solved offline +-- to find a solution. This produces an AST where most of the type +-- annotations are just references to type variables. Further, all the +-- size-specific annotations (e.g. existential sizes) just contain +-- dummy values, such as empty lists. The constraints use a type +-- representation where all dimensions are the same. However, we do +-- try to store the sizes resulting from explicit type ascriptions - +-- these cannot refer to inferred existentials, so it is safe to +-- resolve them here. We don't do anything with this information, +-- however. -- -- 2) Pass (1) has given us a program where we know the types of -- everything, but the sizes of nothing. Pass (2) then does @@ -20,21 +21,6 @@ -- type of everything. This can be implemented using online constraint -- solving (as before), or perhaps a completely syntax-driven -- approach. --- --- As of this writing, only the constraint generation part of pass (1) --- has been implemented, and it is very likely that some of the --- constraints are actually wrong. Next step is to imlement the --- solver. Currently all we do is dump the constraints to the --- terminal. --- --- Also, no thought whatsoever has been put into quality of type --- errors yet. However, I think an approach based on tacking source --- information onto constraints should work well, as all constraints --- ultimately originate from some bit of program syntax. --- --- Also no thought has been put into how to handle the liftedness --- stuff. Since it does not really affect choices made during --- inference, perhaps we can do it in a post-inference check. module Language.Futhark.TypeChecker.Terms2 ( checkValDef, checkSingleExp, @@ -1353,6 +1339,7 @@ generaliseAndDefaults unconstrained solution t = do M.fromList (map (,Scalar (Record mempty)) unconstrained') <> solution' ) +-- | Type check a single value definition. checkValDef :: ( VName, Maybe (TypeExp (ExpBase NoInfo VName) VName), @@ -1430,6 +1417,8 @@ checkValDef (fname, retdecl, tparams, params, body, loc) = runTermM $ do onArtificial artificial solution = M.map (substTyVars (`M.lookup` solution) . first (const ())) artificial <> solution +-- | Type check a single expression, which may have a polymorphic +-- type. checkSingleExp :: ExpBase NoInfo VName -> TypeM (Either TypeError ([TypeParam], M.Map TyVar (TypeBase () NoUniqueness)), Exp) @@ -1449,8 +1438,8 @@ checkSingleExp e = runTermM $ do x <- generaliseAndDefaults unconstrained solution $ first (const ()) e_t pure (Right x, e'') --- | Type-check a single size expression in isolation. This expression may --- turn out to be polymorphic, in which case it is unified with i64. +-- | Type-check a single size expression in isolation, which must have +-- type @i64@. checkSizeExp :: ExpBase NoInfo VName -> TypeM (Either TypeError ([UnconTyVar], M.Map TyVar (TypeBase () NoUniqueness)), Exp) From 6df3d6230adcd784eb3e5fcce9deb65c55a2053c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 28 Jan 2025 22:39:06 +0100 Subject: [PATCH 274/296] Separate constraints from solving. Also, be precise about when we move from SComp shapes to () shapes. --- futhark.cabal | 1 + .../Futhark/TypeChecker/Constraints.hs | 734 ++---------------- src/Language/Futhark/TypeChecker/Rank.hs | 166 ++-- src/Language/Futhark/TypeChecker/Terms2.hs | 22 +- src/Language/Futhark/TypeChecker/TySolve.hs | 668 ++++++++++++++++ src/Language/Futhark/TypeChecker/Unify.hs | 6 +- 6 files changed, 818 insertions(+), 779 deletions(-) create mode 100644 src/Language/Futhark/TypeChecker/TySolve.hs diff --git a/futhark.cabal b/futhark.cabal index 31fb133563..a2e0ef6dcf 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -425,6 +425,7 @@ library Language.Futhark.TypeChecker Language.Futhark.TypeChecker.Consumption Language.Futhark.TypeChecker.Constraints + Language.Futhark.TypeChecker.TySolve Language.Futhark.TypeChecker.Error Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 363735b0b8..a4c4971afa 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -4,35 +4,23 @@ module Language.Futhark.TypeChecker.Constraints ( Reason (..), SVar, SComp (..), - Type, - toType, - Ct (..), - Constraints, + CtType, + CtTy (..), + CtAM (..), TyVarInfo (..), + Level, TyVar, TyVars, TyParams, - Solution, - UnconTyVar, - solve, ) where -import Control.Monad -import Control.Monad.Except -import Control.Monad.State -import Data.Bifunctor -import Data.List qualified as L import Data.Loc import Data.Map qualified as M -import Data.Maybe -import Data.Set qualified as S import Futhark.Util.Pretty import Language.Futhark -import Language.Futhark.TypeChecker.Error -import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote) -import Language.Futhark.TypeChecker.Types (substTyVars) +-- | A shape variable. type SVar = VName -- | A shape component. `SDim` is a single dimension of unspecified @@ -51,34 +39,28 @@ instance Pretty SComp where instance Pretty (Shape SComp) where pretty = mconcat . map pretty . shapeDims --- | The type representation used by the constraint solver. Agnostic --- to sizes. -type Type = TypeBase SComp NoUniqueness - --- | Careful when using this on something that already has an SComp --- size: it will throw away information by converting them to SDim. -toType :: TypeBase Size u -> TypeBase SComp u -toType = first (const SDim) +-- | The type representation used by the constraint solver. +type CtType d = TypeBase d NoUniqueness -- | The reason for a type constraint. Used to generate type error -- messages. The expected type is always the first one. -data Reason +data Reason t = -- | No particular reason. Reason Loc | -- | Arising from pattern match. - ReasonPatMatch Loc (PatBase NoInfo VName ParamType) Type + ReasonPatMatch Loc (PatBase NoInfo VName ParamType) t | -- | Arising from explicit ascription. - ReasonAscription Loc Type Type - | ReasonRetType Loc Type Type - | ReasonApply Loc (Maybe (QualName VName), Int) Exp Type Type + ReasonAscription Loc t t + | ReasonRetType Loc t t + | ReasonApply Loc (Maybe (QualName VName), Int) Exp t t | -- | Used when unifying a type with a function type in a function -- application. If this unification fails, it means the supposed -- function was not a function after all. - ReasonApplySplit Loc (Maybe (QualName VName), Int) Exp Type - | ReasonBranches Loc Type Type - deriving (Eq, Show) + ReasonApplySplit Loc (Maybe (QualName VName), Int) Exp t + | ReasonBranches Loc t t + deriving (Eq, Show, Functor, Foldable, Traversable) -instance Located Reason where +instance Located (Reason t) where locOf (Reason l) = l locOf (ReasonPatMatch l _ _) = l locOf (ReasonAscription l _ _) = l @@ -87,46 +69,58 @@ instance Located Reason where locOf (ReasonApplySplit l _ _ _) = l locOf (ReasonBranches l _ _) = l -data Ct = CtEq Reason Type Type +-- | A type constraint. +data CtTy d = CtEq (Reason (CtType d)) (TypeBase d NoUniqueness) (TypeBase d NoUniqueness) deriving (Show) -ctReason :: Ct -> Reason +ctReason :: CtTy d -> Reason (CtType d) ctReason (CtEq r _ _) = r -instance Located Ct where +instance Located (CtTy d) where locOf = locOf . ctReason -instance Pretty Ct where +instance Pretty (CtTy Size) where + pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2 + +instance Pretty (CtTy SComp) where pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2 -type Constraints = [Ct] +instance Pretty (CtTy ()) where + pretty (CtEq _ t1 t2) = pretty t1 <+> "~" <+> pretty t2 -- | Information about a flexible type variable. Every type variable -- is associated with a location, which is the original syntax element -- that it is the type of. -data TyVarInfo +data TyVarInfo d = -- | Can be substituted with anything. TyVarFree Loc Liftedness | -- | Can only be substituted with these primitive types. TyVarPrim Loc [PrimType] | -- | Must be a record with these fields. - TyVarRecord Loc (M.Map Name Type) + TyVarRecord Loc (M.Map Name (CtType d)) | -- | Must be a sum type with these fields. - TyVarSum Loc (M.Map Name [Type]) + TyVarSum Loc (M.Map Name [CtType d]) deriving (Show, Eq) -instance Pretty TyVarInfo where - pretty (TyVarFree _ l) = "free" <+> pretty l - pretty (TyVarPrim _ pts) = "∈" <+> pretty pts - pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs - pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs +prettyTyVarInfo :: (Pretty (Shape d)) => TyVarInfo d -> Doc a +prettyTyVarInfo (TyVarFree _ l) = "free" <+> pretty l +prettyTyVarInfo (TyVarPrim _ pts) = "∈" <+> pretty pts +prettyTyVarInfo (TyVarRecord _ fs) = pretty $ Scalar $ Record fs +prettyTyVarInfo (TyVarSum _ cs) = pretty $ Scalar $ Sum cs + +instance Pretty (TyVarInfo ()) where + pretty = prettyTyVarInfo + +instance Pretty (TyVarInfo SComp) where + pretty = prettyTyVarInfo -instance Located TyVarInfo where +instance Located (TyVarInfo d) where locOf (TyVarFree loc _) = loc locOf (TyVarPrim loc _) = loc locOf (TyVarRecord loc _) = loc locOf (TyVarSum loc _) = loc +-- | The name of a type variable. type TyVar = VName -- | The level at which a type variable is bound. Higher means @@ -137,647 +131,15 @@ type Level = Int -- | If a VName is not in this map, it should be in the 'TyParams' - -- the exception is abstract types, which are just missing (and -- assumed to have smallest possible level). -type TyVars = M.Map TyVar (Level, TyVarInfo) +type TyVars d = M.Map TyVar (Level, TyVarInfo d) -- | Explicit type parameters. type TyParams = M.Map TyVar (Level, Liftedness, Loc) -data TyVarSol - = -- | Has been substituted with this. - TyVarSol Type - | -- | Is an explicit (rigid) type parameter in the source program. - TyVarParam Level Liftedness Loc - | -- | Not substituted yet; has this constraint. - TyVarUnsol TyVarInfo - deriving (Show) - -newtype SolverState = SolverState - { -- | Left means linked to this other type variable. - solverTyVars :: M.Map TyVar (Either VName TyVarSol) - } - -initialState :: TyParams -> TyVars -> SolverState -initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars - where - f (_lvl, info) = Right $ TyVarUnsol info - g (lvl, l, loc) = Right $ TyVarParam lvl l loc - -substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase SComp u) -substTyVar m v = - case M.lookup v m of - Just (Left v') -> substTyVar m v' - Just (Right (TyVarSol t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' - Just (Right TyVarParam {}) -> Nothing - Just (Right (TyVarUnsol {})) -> Nothing - Nothing -> Nothing - -maybeLookupTyVar :: TyVar -> SolveM (Maybe TyVarSol) -maybeLookupTyVar orig = do - tyvars <- gets solverTyVars - let f v = case M.lookup v tyvars of - Nothing -> pure Nothing - Just (Left v') -> f v' - Just (Right info) -> pure $ Just info - f orig - -lookupTyVar :: TyVar -> SolveM (Either TyVarInfo Type) -lookupTyVar orig = - maybe bad unpack <$> maybeLookupTyVar orig - where - bad = error $ "Unknown tyvar: " <> prettyNameString orig - unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig - unpack (TyVarSol t) = Right t - unpack (TyVarUnsol info) = Left info - --- | Variable must be flexible. -lookupTyVarInfo :: TyVar -> SolveM TyVarInfo -lookupTyVarInfo v = do - r <- lookupTyVar v - case r of - Left info -> pure info - Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v - -setLink :: TyVar -> VName -> SolveM () -setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} - -setInfo :: TyVar -> TyVarSol -> SolveM () -setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} - --- | A solution maps a type variable to its substitution. This --- substitution is complete, in the sense there are no right-hand --- sides that contain a type variable. -type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) - --- | An unconstrained type variable comprises a name and (ironically) --- a constraint on how it can be instantiated. -type UnconTyVar = (VName, Liftedness) - -typeVar :: (Monoid u) => VName -> TypeBase dim u -typeVar v = Scalar $ TypeVar mempty (qualName v) [] - -solution :: SolverState -> ([UnconTyVar], Solution) -solution s = - ( mapMaybe unconstrained $ M.toList $ solverTyVars s, - M.mapMaybe mkSubst $ solverTyVars s - ) - where - mkSubst (Right (TyVarSol t)) = - Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t - mkSubst (Left v') = - Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ - mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (Right (TyVarUnsol (TyVarPrim _ pts))) = Just $ Left pts - mkSubst _ = Nothing - - unconstrained (v, Right (TyVarUnsol (TyVarFree _ l))) = Just (v, l) - unconstrained _ = Nothing - -newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} - deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) - --- Try to substitute as much information as we have. -enrichType :: Type -> SolveM Type -enrichType t = do - s <- get - pure $ substTyVars (substTyVar (solverTyVars s)) t - -typeError :: Loc -> Notes -> Doc () -> SolveM () -typeError loc notes msg = - throwError $ TypeError loc notes msg - -occursCheck :: Reason -> VName -> Type -> SolveM () -occursCheck reason v tp = do - vars <- gets solverTyVars - let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . typeError (locOf reason) mempty $ - "Occurs check: cannot instantiate" - <+> prettyName v - <+> "with" - <+> pretty tp - <> "." - -unifySharedConstructors :: - Reason -> - BreadCrumbs -> - M.Map Name [Type] -> - M.Map Name [Type] -> - SolveM () -unifySharedConstructors reason bcs cs1 cs2 = - forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> - if length ts1 == length ts2 - then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 - else - typeError (locOf reason) mempty $ - "Cannot unify type with constructor" - indent 2 (pretty (Sum (M.singleton c ts1))) - "with type of constructor" - indent 2 (pretty (Sum (M.singleton c ts2))) - "because they differ in arity." - -unifySharedFields :: - Reason -> - BreadCrumbs -> - M.Map Name Type -> - M.Map Name Type -> - SolveM () -unifySharedFields reason bcs fs1 fs2 = - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> - solveEq reason (matchingField f <> bcs) ts1 ts2 - -scopeViolation :: Reason -> VName -> Type -> VName -> SolveM () -scopeViolation reason v1 ty v2 = - typeError (locOf reason) mempty $ - "Cannot unify type" - indent 2 (pretty ty) - "with" - <+> dquotes (prettyName v1) - <+> "(scope violation)." - "This is because" - <+> dquotes (prettyName v2) - <+> "is rigidly bound in a deeper scope." - -cannotUnify :: - Reason -> - Notes -> - BreadCrumbs -> - Type -> - Type -> - SolveM () -cannotUnify reason notes bcs t1 t2 = do - t1' <- enrichType t1 - t2' <- enrichType t2 - case reason of - Reason loc -> - typeError loc notes . stack $ - [ "Cannot unify", - indent 2 (pretty t1'), - "with", - indent 2 (pretty t2') - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - ReasonPatMatch loc pat value_t -> - typeError loc notes . stack $ - [ "Pattern", - indent 2 $ align $ pretty pat, - "cannot match value of type", - indent 2 $ align $ pretty value_t - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - ReasonAscription loc expected actual -> - typeError loc notes . stack $ - [ "Expression does not have expected type from type ascription.", - "Expected:" <+> align (pretty expected), - "Actual: " <+> align (pretty actual) - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - ReasonRetType loc expected actual -> do - expected' <- enrichType expected - actual' <- enrichType actual - typeError loc notes . stack $ - [ "Function body does not have expected type.", - "Expected:" <+> align (pretty expected'), - "Actual: " <+> align (pretty actual') - ] - <> [pretty bcs | not $ hasNoBreadCrumbs bcs] - ReasonApply loc f e expected actual -> do - expected' <- enrichType expected - actual' <- enrichType actual - typeError loc notes . stack $ - [ header, - "Expected:" <+> align (pretty expected'), - "Actual: " <+> align (pretty actual') - ] - where - header = - case f of - (Nothing, _) -> - "Cannot apply function to" - <+> dquotes (shorten $ group $ pretty e) - <> " (invalid type)." - (Just fname, _) -> - "Cannot apply" - <+> dquotes (pretty fname) - <+> "to" - <+> dquotes (align $ shorten $ group $ pretty e) - <> " (invalid type)." - ReasonApplySplit loc (fname, 0) _ ftype -> - typeError loc notes $ - stack - [ "Cannot apply" - <+> fname' - <+> "as function, as it has non-function type:" - indent 2 (align $ pretty ftype) - ] - where - fname' = maybe "expression" (dquotes . pretty) fname - ReasonApplySplit loc (fname, i) e _ -> - typeError loc notes $ - stack - [ "Cannot apply" - <+> fname' - <+> "to" - <+> dquotes (align $ shorten $ group $ pretty e) - <> ".", - "Function accepts only" <+> pretty i <+> "arguments." - ] - where - fname' = maybe "expression" (dquotes . pretty) fname - ReasonBranches loc former latter -> do - former' <- enrichType former - latter' <- enrichType latter - typeError loc notes . stack $ - [ "Branches differ in type.", - "Former:" <+> pretty former', - "Latter:" <+> pretty latter' - ] - --- Precondition: 'v' is currently flexible. -subTyVar :: Reason -> BreadCrumbs -> VName -> Type -> SolveM () -subTyVar reason bcs v t = do - occursCheck reason v t - v_info <- gets $ M.lookup v . solverTyVars - - -- Set a solution for v, then update info for t in case v has any - -- odd constraints. - setInfo v (TyVarSol t) - - case (v_info, t) of - (Just (Right (TyVarUnsol TyVarFree {})), _) -> - pure () - ( Just (Right (TyVarUnsol (TyVarPrim _ v_pts))), - _ - ) -> - if t `elem` map (Scalar . Prim) v_pts - then pure () - else cannotUnify reason notes bcs (typeVar v) t - where - notes = - aNote $ - "Cannot instance type that must be one of" - indent 2 (pretty v_pts) - "with" - indent 2 (pretty t) - ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), - Scalar (Sum cs2) - ) -> - if all (`elem` M.keys cs2) (M.keys cs1) - then unifySharedConstructors reason bcs cs1 cs2 - else cannotUnify reason notes bcs (typeVar v) t - where - notes = - aNote $ - "Cannot match type with constructors" - indent 2 (stack (map (("#" <>) . pretty) (M.keys cs1))) - "with type with constructors" - indent 2 (stack (map (("#" <>) . pretty) (M.keys cs2))) - unsharedConstructorsMsg cs1 cs2 - ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), - _ - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type with constructors" - indent 2 (pretty (Sum cs1)) - "with type" - indent 2 (pretty t) - ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), - Scalar (Record fs2) - ) -> - if all (`elem` M.keys fs2) (M.keys fs1) - then unifySharedFields reason bcs fs1 fs2 - else - typeError (locOf reason) mempty $ - "Cannot unify record type with fields" - indent 2 (pretty (Record fs1)) - "with record type" - indent 2 (pretty (Record fs2)) - ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), - _ - ) -> - typeError (locOf reason) mempty $ - "Cannot unify record type with fields" - indent 2 (pretty (Record fs1)) - "with type" - indent 2 (pretty t) - -- - -- Internal error cases - (Just (Right TyVarSol {}), _) -> - error $ "Type variable already solved: " <> prettyNameString v - (Just (Right TyVarParam {}), _) -> - error $ "Cannot substitute type parameter: " <> prettyNameString v - (Just Left {}, _) -> - error $ "Type variable already linked: " <> prettyNameString v - (Nothing, _) -> - error $ "subTyVar: Nothing v: " <> prettyNameString v - --- Precondition: 'v' and 't' are both currently flexible. --- --- The purpose of this function is to combine the partial knowledge we --- may have about these two type variables. -unionTyVars :: Reason -> BreadCrumbs -> VName -> VName -> SolveM () -unionTyVars reason bcs v t = do - v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - t_info <- lookupTyVarInfo t - - -- Insert the link from v to t, and then update the info of t based - -- on the existing info of v and t. - setLink v t - - case (v_info, t_info) of - ( TyVarUnsol (TyVarFree _ v_l), - TyVarFree t_loc t_l - ) - | v_l /= t_l -> - setInfo t $ TyVarUnsol $ TyVarFree t_loc (min v_l t_l) - -- When either is completely unconstrained. - (TyVarUnsol TyVarFree {}, _) -> - pure () - ( TyVarUnsol info, - TyVarFree {} - ) -> - setInfo t (TyVarUnsol info) - -- - -- TyVarPrim cases - ( TyVarUnsol (TyVarPrim _ v_pts), - TyVarPrim t_loc t_pts - ) -> - let pts = L.intersect v_pts t_pts - in if null pts - then - typeError (locOf reason) mempty $ - "Cannot unify type that must be one of" - indent 2 (pretty v_pts) - "with type that must be one of" - indent 2 (pretty t_pts) - else setInfo t (TyVarUnsol (TyVarPrim t_loc pts)) - ( TyVarUnsol (TyVarPrim _ v_pts), - TyVarRecord {} - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type that must be one of" - indent 2 (pretty v_pts) - "with type that must be a record." - ( TyVarUnsol (TyVarPrim _ v_pts), - TyVarSum {} - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type that must be one of" - indent 2 (pretty v_pts) - "with type that must be sum." - -- - -- TyVarSum cases - ( TyVarUnsol (TyVarSum _ cs1), - TyVarSum loc cs2 - ) -> do - unifySharedConstructors reason bcs cs1 cs2 - let cs3 = cs1 <> cs2 - setInfo t (TyVarUnsol (TyVarSum loc cs3)) - ( TyVarUnsol TyVarSum {}, - TyVarPrim _ pts - ) -> - typeError (locOf reason) mempty $ - "A sum type cannot be one of" - indent 2 (pretty pts) - ( TyVarUnsol (TyVarSum _ cs1), - TyVarRecord _ fs - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type with constructors" - indent 2 (pretty (Sum cs1)) - "with type" - indent 2 (pretty (Scalar (Record fs))) - -- - -- TyVarRecord cases - ( TyVarUnsol (TyVarRecord _ fs1), - TyVarRecord loc fs2 - ) -> do - unifySharedFields reason bcs fs1 fs2 - let fs3 = fs1 <> fs2 - setInfo t (TyVarUnsol (TyVarRecord loc fs3)) - ( TyVarUnsol TyVarRecord {}, - TyVarPrim _ pts - ) -> - typeError (locOf reason) mempty $ - "A record type cannot be one of" - indent 2 (pretty pts) - ( TyVarUnsol (TyVarRecord _ fs1), - TyVarSum _ cs - ) -> - typeError (locOf reason) mempty $ - "Cannot unify record type" - indent 2 (pretty (Record fs1)) - "with type" - indent 2 (pretty (Scalar (Sum cs))) - -- - -- Internal error cases - (TyVarSol {}, _) -> - alreadySolved - (TyVarParam {}, _) -> - isParam - where - unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v - alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v - alreadySolved = error $ "Type variable already solved: " <> prettyNameString v - isParam = error $ "Type name is a type parameter: " <> prettyNameString v - -unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a -unsharedConstructorsMsg cs1 cs2 = - "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." - where - missing = - filter (`notElem` M.keys cs1) (M.keys cs2) - ++ filter (`notElem` M.keys cs2) (M.keys cs1) - --- Unify at the root, emitting new equalities that must hold. -unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] -unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) - | pt1 == pt2 = Right [] -unify - (Scalar (TypeVar _ (QualName _ v1) targs1)) - (Scalar (TypeVar _ (QualName _ v2) targs2)) - | v1 == v2 = - Right $ mapMaybe f $ zip targs1 targs2 - where - f (TypeArgType t1, TypeArgType t2) = Just (mempty, (t1, t2)) - f _ = Nothing -unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] - where - t1r' = t1r `setUniqueness` NoUniqueness - t2r' = t2r `setUniqueness` NoUniqueness -unify (Scalar (Record fs1)) (Scalar (Record fs2)) - | M.keys fs1 == M.keys fs2 = - Right $ - map (first matchingField) $ - M.toList $ - M.intersectionWith (,) fs1 fs2 - | Just n1 <- length <$> areTupleFields fs1, - Just n2 <- length <$> areTupleFields fs2, - n1 /= n2 = - Left $ - "Tuples have" - <+> pretty n1 - <+> "and" - <+> pretty n2 - <+> "elements respectively." - | otherwise = - let missing = - filter (`notElem` M.keys fs1) (M.keys fs2) - <> filter (`notElem` M.keys fs2) (M.keys fs1) - in Left $ - "unshared fields:" <+> commasep (map pretty missing) <> "." -unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) - | M.keys cs1 == M.keys cs2 = - fmap concat . forM cs' $ \(c, (ts1, ts2)) -> do - if length ts1 == length ts2 - then Right $ zipWith (curry (matchingConstructor c,)) ts1 ts2 - else Left mempty - | otherwise = - Left $ unsharedConstructorsMsg cs1 cs2 - where - cs' = M.toList $ M.intersectionWith (,) cs1 cs2 -unify t1 t2 - | Just t1' <- peelArray 1 t1, - Just t2' <- peelArray 1 t2 = - Right [(mempty, (t1', t2'))] -unify _ _ = Left mempty - -solveEq :: Reason -> BreadCrumbs -> Type -> Type -> SolveM () -solveEq reason obcs orig_t1 orig_t2 = do - solveCt' (obcs, (orig_t1, orig_t2)) - where - solveCt' (bcs, (t1, t2)) = do - tyvars <- gets solverTyVars - let flexible v = case M.lookup v tyvars of - Just (Left v') -> flexible v' - Just (Right (TyVarUnsol _)) -> True - Just (Right TyVarSol {}) -> False - Just (Right TyVarParam {}) -> False - Nothing -> False - sub t@(Scalar (TypeVar u (QualName [] v) [])) = - case M.lookup v tyvars of - Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (Right (TyVarSol t')) -> sub t' - _ -> t - sub t = t - case (sub t1, sub t2) of - ( t1'@(Scalar (TypeVar _ (QualName [] v1) [])), - t2'@(Scalar (TypeVar _ (QualName [] v2) [])) - ) - | v1 == v2 -> pure () - | otherwise -> - case (flexible v1, flexible v2) of - (False, False) -> cannotUnify reason mempty bcs t1 t2 - (True, False) -> subTyVar reason bcs v1 t2' - (False, True) -> subTyVar reason bcs v2 t1' - (True, True) -> unionTyVars reason bcs v1 v2 - (Scalar (TypeVar _ (QualName [] v1) []), t2') - | flexible v1 -> subTyVar reason bcs v1 t2' - (t1', Scalar (TypeVar _ (QualName [] v2) [])) - | flexible v2 -> subTyVar reason bcs v2 t1' - (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify reason (aNote details) bcs t1' t2' - Right eqs -> mapM_ solveCt' eqs - -solveCt :: Ct -> SolveM () -solveCt ct = - case ct of - CtEq reason t1 t2 -> solveEq reason mempty t1 t2 - -scopeCheck :: Reason -> TyVar -> Int -> Type -> SolveM () -scopeCheck reason v v_lvl ty = do - mapM_ check $ typeVars ty - where - check ty_v = do - ty_v_info <- gets $ M.lookup ty_v . solverTyVars - case ty_v_info of - Just (Right (TyVarParam ty_v_lvl _ _)) - | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v - Just (Right (TyVarSol ty')) -> - mapM_ check $ typeVars ty' - _ -> pure () - --- If a type variable has a liftedness constraint, we propagate that --- constraint to its solution. The actual checking for correct usage --- is done later. -liftednessCheck :: Liftedness -> Type -> SolveM () -liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do - v_info <- maybeLookupTyVar v - case v_info of - Nothing -> - -- Is an opaque type. - pure () - Just (TyVarSol v_ty) -> - liftednessCheck l v_ty - Just TyVarParam {} -> pure () - Just (TyVarUnsol (TyVarFree loc v_l)) - | l /= v_l -> - setInfo v $ TyVarUnsol $ TyVarFree loc (min l v_l) - Just TyVarUnsol {} -> pure () -liftednessCheck _ (Scalar Prim {}) = pure () -liftednessCheck Lifted _ = pure () -liftednessCheck _ Array {} = pure () -liftednessCheck _ (Scalar Arrow {}) = pure () -liftednessCheck l (Scalar (Record fs)) = - mapM_ (liftednessCheck l) fs -liftednessCheck l (Scalar (Sum cs)) = - mapM_ (mapM_ $ liftednessCheck l) cs -liftednessCheck _ (Scalar TypeVar {}) = pure () +data CtAM = CtAM (Reason (CtType SComp)) SVar SVar (Shape SComp) -solveTyVar :: (VName, (Level, TyVarInfo)) -> SolveM () -solveTyVar (tv, (_, TyVarRecord loc fs1)) = do - tv_t <- lookupTyVar tv - case tv_t of - Left _ -> - typeError loc mempty $ - "Type" - <+> prettyName tv - <+> "is ambiguous." - "Must be a record with fields" - indent 2 (pretty (Scalar (Record fs1))) - Right _ -> - pure () -solveTyVar (tv, (_, TyVarSum loc cs1)) = do - tv_t <- lookupTyVar tv - case tv_t of - Left _ -> - typeError loc mempty $ - "Type is ambiguous." - "Must be a sum type with constructors" - indent 2 (pretty (Scalar (Sum cs1))) - Right _ -> pure () -solveTyVar (tv, (lvl, TyVarFree loc l)) = do - tv_t <- lookupTyVar tv - case tv_t of - Right ty -> do - scopeCheck (Reason loc) tv lvl ty - liftednessCheck l ty - _ -> pure () -solveTyVar (tv, (_, TyVarPrim loc pts)) = do - tv_t <- lookupTyVar tv - case tv_t of - Right (Scalar (Prim ty)) - | [ty] == pts -> - setInfo tv $ TyVarSol $ Scalar $ Prim ty - Right ty - | ty `elem` map (Scalar . Prim) pts -> pure () - | otherwise -> - typeError loc mempty $ - "Numeric constant inferred to be of type" - indent 2 (align (pretty ty)) - "which is not possible." - _ -> pure () +instance Located CtAM where + locOf (CtAM r _ _ _) = locOf r -solve :: - Constraints -> - TyParams -> - TyVars -> - Either TypeError ([UnconTyVar], Solution) -solve constraints typarams tyvars = - second solution - . runExcept - . flip execStateT (initialState typarams tyvars) - . runSolveM - $ do - mapM_ solveCt constraints - mapM_ solveTyVar (M.toList tyvars) -{-# NOINLINE solve #-} +instance Pretty CtAM where + pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index da20eee308..c471395422 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -26,22 +26,12 @@ import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad import System.IO.Unsafe -data CtAM = CtAM Reason SVar SVar (Shape SComp) - -instance Located CtAM where - locOf (CtAM r _ _ _) = locOf r - -instance Pretty CtAM where - pretty (CtAM _ r m _) = prettyName r <+> "=" <+> "•" <+> "∨" <+> prettyName m <+> "=" <+> "•" - type LSum = LP.LSum VName Int type Constraint = LP.Constraint VName Int type LinearProg = LP.LinearProg VName Int -type ScalarType = ScalarTypeBase SComp NoUniqueness - class Rank a where rank :: a -> LSum @@ -55,7 +45,7 @@ instance Rank SComp where instance Rank (Shape SComp) where rank = foldr (\d r -> rank d ~+~ r) (constant 0) . shapeDims -instance Rank ScalarType where +instance Rank (ScalarTypeBase SComp u) where rank Prim {} = constant 0 rank (TypeVar _ (QualName [] v) []) = var v rank (TypeVar {}) = constant 0 @@ -63,11 +53,11 @@ instance Rank ScalarType where rank (Record {}) = constant 0 rank (Sum {}) = constant 0 -instance Rank Type where +instance Rank (TypeBase SComp u) where rank (Scalar t) = rank t rank (Array _ shape t) = rank shape ~+~ rank t -distribAndSplitArrows :: Ct -> [Ct] +distribAndSplitArrows :: CtTy d -> [CtTy d] distribAndSplitArrows (CtEq r t1 t2) = splitArrows $ CtEq r (distribute t1) (distribute t2) where @@ -94,7 +84,7 @@ distribAndSplitArrows (CtEq r t1 t2) = t2r' = t2r `setUniqueness` NoUniqueness splitArrows c = [c] -distribAndSplitCnstrs :: Ct -> [Ct] +distribAndSplitCnstrs :: CtTy d -> [CtTy d] distribAndSplitCnstrs ct@(CtEq r t1 t2) = ct : splitCnstrs (CtEq r (distribute1 t1) (distribute1 t2)) where @@ -153,7 +143,7 @@ addObj :: SVar -> RankM () addObj sv = modify $ \s -> s {rankObj = rankObj s ~+~ var sv} -addCt :: Ct -> RankM () +addCt :: CtTy SComp -> RankM () addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2 addCtAM :: CtAM -> RankM () @@ -168,7 +158,7 @@ addCtAM (CtAM _ r m f) = do addObj m addObj tr -addTyVarInfo :: TyVar -> (Int, TyVarInfo) -> RankM () +addTyVarInfo :: TyVar -> (Int, TyVarInfo d) -> RankM () addTyVarInfo _ (_, TyVarFree {}) = pure () addTyVarInfo tv (_, TyVarPrim {}) = addConstraint $ rank tv ~==~ constant 0 @@ -177,7 +167,7 @@ addTyVarInfo tv (_, TyVarRecord {}) = addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 -mkLinearProg :: [Ct] -> [CtAM] -> TyVars -> LinearProg +mkLinearProg :: [CtTy SComp] -> [CtAM] -> TyVars d -> LinearProg mkLinearProg cs cs_am tyVars = LP.LinearProg { optType = Minimize, @@ -259,14 +249,14 @@ solveRankILP loc prog = do rankAnalysis1 :: (MonadTypeChecker m) => SrcLoc -> - ([Ct], [CtAM]) -> - TyVars -> - M.Map TyVar Type -> + ([CtTy SComp], [CtAM]) -> + TyVars SComp -> + M.Map TyVar (CtType SComp) -> [Pat ParamType] -> Exp -> Maybe (TypeExp Exp VName) -> m - ( ([Ct], M.Map TyVar Type, TyVars), + ( ([CtTy ()], M.Map TyVar (CtType ()), TyVars ()), [Pat ParamType], Exp, Maybe (TypeExp Exp VName) @@ -287,21 +277,22 @@ rankAnalysis1 loc (cs, cs_am) tyVars artificial params body retdecl = do rankAnalysis :: (MonadTypeChecker m) => SrcLoc -> - ([Ct], [CtAM]) -> - TyVars -> - M.Map TyVar Type -> + ([CtTy SComp], [CtAM]) -> + TyVars SComp -> + M.Map TyVar (CtType SComp) -> [Pat ParamType] -> Exp -> Maybe (TypeExp Exp VName) -> m - [ ( ([Ct], M.Map TyVar Type, TyVars), + [ ( ([CtTy ()], M.Map TyVar (CtType ()), TyVars ()), [Pat ParamType], Exp, Maybe (TypeExp Exp VName) ) ] -rankAnalysis _ ([], []) tyVars artificial params body retdecl = - pure [(([], artificial, tyVars), params, body, retdecl)] +rankAnalysis _ ([], []) tyVars artificial params body retdecl = do + (_, artificial', tyVars') <- substRankInfo ([], []) artificial tyVars mempty + pure [(([], artificial', tyVars'), params, body, retdecl)] rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl = do debugTraceM 3 $ unlines @@ -326,18 +317,26 @@ type RankMap = M.Map VName Int substRankInfo :: (MonadTypeChecker m) => - ([Ct], [CtAM]) -> - M.Map VName Type -> - TyVars -> + ([CtTy SComp], [CtAM]) -> + M.Map VName (CtType SComp) -> + TyVars SComp -> RankMap -> - m ([Ct], M.Map VName Type, TyVars) + m ([CtTy ()], M.Map VName (CtType ()), TyVars ()) substRankInfo (cs, _cs_am) artificial tyVars rankmap = do ((cs', artificial', tyVars'), new_cs, new_tyVars) <- runSubstT tyVars rankmap $ - (,,) <$> substRanks cs <*> traverse substRanks artificial <*> traverse substRanks tyVars + (,,) + <$> traverse substRanksCt cs + <*> traverse substRanksType artificial + <*> substRanksTyVars tyVars pure (cs' <> new_cs, artificial', new_tyVars <> tyVars') -runSubstT :: (MonadTypeChecker m) => TyVars -> RankMap -> SubstT m a -> m (a, [Ct], TyVars) +runSubstT :: + (MonadTypeChecker m) => + TyVars SComp -> + RankMap -> + SubstT m a -> + m (a, [CtTy ()], TyVars ()) runSubstT tyVars rankmap (SubstT m) = do let env = SubstEnv @@ -364,14 +363,14 @@ newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) ) data SubstEnv = SubstEnv - { envTyVars :: TyVars, + { envTyVars :: TyVars SComp, envRanks :: RankMap } data SubstState = SubstState - { substTyVars :: TyVars, + { substTyVars :: TyVars (), substNewVars :: Map TyVar TyVar, - substNewCts :: [Ct] + substNewCts :: [CtTy ()] } instance MonadTrans SubstT where @@ -395,10 +394,10 @@ newTyVar t = do } pure t' -rankToShape :: (Monad m) => VName -> SubstT m (Shape SComp) +rankToShape :: (Monad m) => VName -> SubstT m (Shape ()) rankToShape x = do rs <- asks envRanks - pure $ Shape $ replicate (fromJust $ rs M.!? x) SDim + pure $ Shape $ replicate (fromJust $ rs M.!? x) () addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do @@ -416,51 +415,56 @@ addRankInfo t = do l = case tvinfo of TyVarFree _ tvinfo_l -> tvinfo_l _ -> Unlifted - modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo) $ substTyVars s} + tvinfo' <- substRanksTyVarInfo tvinfo + modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo') $ substTyVars s} modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree (locOf tvinfo) l) $ substTyVars s} -class SubstRanks a where - substRanks :: (MonadTypeChecker m) => a -> SubstT m a - -instance (SubstRanks a) => SubstRanks [a] where - substRanks = mapM substRanks - -instance SubstRanks (Shape SComp) where - substRanks = foldM (\s d -> (s <>) <$> instDim d) mempty - where - instDim SDim = pure $ Shape $ pure SDim - instDim (SVar x) = rankToShape x - -instance SubstRanks (TypeBase SComp u) where - substRanks t@(Scalar (TypeVar _ (QualName [] x) [])) = - addRankInfo x >> pure t - substRanks (Scalar (Arrow u p d ta (RetType retdims tr))) = do - ta' <- substRanks ta - tr' <- substRanks tr - pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) - substRanks (Scalar (Record fs)) = - Scalar . Record <$> traverse substRanks fs - substRanks (Scalar (Sum cs)) = - Scalar . Sum <$> (traverse . traverse) substRanks cs - substRanks (Array u shape t) = do - shape' <- substRanks shape - t' <- substRanks $ Scalar t - pure $ arrayOfWithAliases u shape' t' - substRanks t = pure t - -instance SubstRanks Ct where - substRanks (CtEq r t1 t2) = CtEq r <$> substRanks t1 <*> substRanks t2 - -instance SubstRanks TyVarInfo where - substRanks tv@TyVarFree {} = pure tv - substRanks tv@TyVarPrim {} = pure tv - substRanks (TyVarRecord loc fs) = - TyVarRecord loc <$> traverse substRanks fs - substRanks (TyVarSum loc cs) = - TyVarSum loc <$> (traverse . traverse) substRanks cs - -instance SubstRanks (Int, TyVarInfo) where - substRanks (lvl, tv) = (lvl,) <$> substRanks tv +substRanksShape :: (Monad m) => Shape SComp -> SubstT m (Shape ()) +substRanksShape = foldM (\s d -> (s <>) <$> instDim d) mempty + where + instDim SDim = pure $ Shape [()] + instDim (SVar x) = rankToShape x + +substRanksType :: (MonadTypeChecker m) => TypeBase SComp u -> SubstT m (TypeBase () u) +substRanksType (Scalar (TypeVar vn (QualName qs x) targs)) = do + when (null qs) $ addRankInfo x + targs' <- mapM onTypeArg targs + pure $ Scalar $ TypeVar vn (QualName qs x) targs' + where + onTypeArg (TypeArgType t) = TypeArgType <$> substRanksType t + -- SVar cannot occur as argument to abstract ype. + onTypeArg (TypeArgDim _) = pure $ TypeArgDim () +substRanksType (Scalar (Arrow u p d ta (RetType retdims tr))) = do + ta' <- substRanksType ta + tr' <- substRanksType tr + pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) +substRanksType (Scalar (Record fs)) = + Scalar . Record <$> traverse substRanksType fs +substRanksType (Scalar (Sum cs)) = + Scalar . Sum <$> (traverse . traverse) substRanksType cs +substRanksType (Scalar (Prim pt)) = pure $ Scalar $ Prim pt +substRanksType (Array u shape t) = do + shape' <- substRanksShape shape + t' <- substRanksType $ Scalar t + pure $ arrayOfWithAliases u shape' t' + +substRanksCt :: (MonadTypeChecker m) => CtTy SComp -> SubstT m (CtTy ()) +substRanksCt (CtEq r t1 t2) = + CtEq + <$> traverse substRanksType r + <*> substRanksType t1 + <*> substRanksType t2 + +substRanksTyVarInfo :: (MonadTypeChecker m) => TyVarInfo SComp -> SubstT m (TyVarInfo ()) +substRanksTyVarInfo (TyVarFree loc l) = pure $ TyVarFree loc l +substRanksTyVarInfo (TyVarPrim loc ts) = pure $ TyVarPrim loc ts +substRanksTyVarInfo (TyVarRecord loc fs) = + TyVarRecord loc <$> traverse substRanksType fs +substRanksTyVarInfo (TyVarSum loc cs) = + TyVarSum loc <$> traverse (traverse substRanksType) cs + +substRanksTyVars :: (MonadTypeChecker m) => TyVars SComp -> SubstT m (TyVars ()) +substRanksTyVars = traverse $ \(lvl, tv) -> (lvl,) <$> substRanksTyVarInfo tv updAM :: RankMap -> Exp -> Exp updAM rank_map e = diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index eee234a221..297922f12e 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -54,8 +54,9 @@ import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Rank +import Language.Futhark.TypeChecker.TySolve hiding (Type) import Language.Futhark.TypeChecker.Types -import Language.Futhark.TypeChecker.Unify (Level, mkUsage) +import Language.Futhark.TypeChecker.Unify (mkUsage) import Prelude hiding (mod) data Inferred t @@ -66,6 +67,13 @@ instance Functor Inferred where fmap _ NoneInferred = NoneInferred fmap f (Ascribed t) = Ascribed (f t) +type Type = CtType SComp + +-- | Careful when using this on something that already has an SComp +-- size: it will throw away information by converting them to SDim. +toType :: TypeBase Size u -> TypeBase SComp u +toType = first (const SDim) + data ValBinding = BoundV [TypeParam] Type | OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType) @@ -97,9 +105,9 @@ data TermEnv = TermEnv -- type names. This is distinct from the usual counter we use for -- generating unique names, as these will be user-visible. data TermState = TermState - { termConstraints :: Constraints, + { termConstraints :: [CtTy SComp], termAM :: [CtAM], - termTyVars :: TyVars, + termTyVars :: TyVars SComp, termTyParams :: TyParams, termCounter :: !Int, termWarnings :: Warnings, @@ -212,7 +220,7 @@ incCounter = do tyVarType :: u -> TyVar -> TypeBase dim u tyVarType u v = Scalar $ TypeVar u (qualName v) [] -newTyVarWith :: Name -> TyVarInfo -> TermM TyVar +newTyVarWith :: Name -> TyVarInfo SComp -> TermM TyVar newTyVarWith desc info = do i <- incCounter v <- newID $ mkTypeVarName desc i @@ -286,10 +294,10 @@ asType t = do expType :: Exp -> TermM Type expType = asType . typeOf -- NOTE: Only place you should use typeOf. -addCt :: Ct -> TermM () +addCt :: CtTy SComp -> TermM () addCt ct = modify $ \s -> s {termConstraints = ct : termConstraints s} -ctEq :: Reason -> TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () +ctEq :: Reason (CtType SComp) -> TypeBase SComp u1 -> TypeBase SComp u2 -> TermM () ctEq reason t1 t2 = -- As a minor optimisation, do not add constraint if the types are -- equal. @@ -298,7 +306,7 @@ ctEq reason t1 t2 = t1' = t1 `setUniqueness` NoUniqueness t2' = t2 `setUniqueness` NoUniqueness -ctAM :: Reason -> SVar -> SVar -> Shape SComp -> TermM () +ctAM :: Reason (CtType SComp) -> SVar -> SVar -> Shape SComp -> TermM () ctAM reason r m f = modify $ \s -> s {termAM = ct : termAM s} where diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs new file mode 100644 index 0000000000..058303c43c --- /dev/null +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -0,0 +1,668 @@ +module Language.Futhark.TypeChecker.TySolve + ( Type, + TyParams, + Solution, + UnconTyVar, + solve, + ) +where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.State +import Data.Bifunctor +import Data.List qualified as L +import Data.Loc +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.Util.Pretty +import Language.Futhark +import Language.Futhark.TypeChecker.Constraints +import Language.Futhark.TypeChecker.Error +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote) +import Language.Futhark.TypeChecker.Types (substTyVars) + +-- | The type representation used by the constraint solver. Agnostic +-- to sizes. +type Type = CtType () + +data TyVarSol + = -- | Has been substituted with this. + TyVarSol Type + | -- | Is an explicit (rigid) type parameter in the source program. + TyVarParam Level Liftedness Loc + | -- | Not substituted yet; has this constraint. + TyVarUnsol (TyVarInfo ()) + deriving (Show) + +newtype SolverState = SolverState + { -- | Left means linked to this other type variable. + solverTyVars :: M.Map TyVar (Either VName TyVarSol) + } + +initialState :: TyParams -> TyVars () -> SolverState +initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars + where + f (_lvl, info) = Right $ TyVarUnsol info + g (lvl, l, loc) = Right $ TyVarParam lvl l loc + +substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase () u) +substTyVar m v = + case M.lookup v m of + Just (Left v') -> substTyVar m v' + Just (Right (TyVarSol t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right TyVarParam {}) -> Nothing + Just (Right (TyVarUnsol {})) -> Nothing + Nothing -> Nothing + +maybeLookupTyVar :: TyVar -> SolveM (Maybe TyVarSol) +maybeLookupTyVar orig = do + tyvars <- gets solverTyVars + let f v = case M.lookup v tyvars of + Nothing -> pure Nothing + Just (Left v') -> f v' + Just (Right info) -> pure $ Just info + f orig + +lookupTyVar :: TyVar -> SolveM (Either (TyVarInfo ()) Type) +lookupTyVar orig = + maybe bad unpack <$> maybeLookupTyVar orig + where + bad = error $ "Unknown tyvar: " <> prettyNameString orig + unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig + unpack (TyVarSol t) = Right t + unpack (TyVarUnsol info) = Left info + +-- | Variable must be flexible. +lookupTyVarInfo :: TyVar -> SolveM (TyVarInfo ()) +lookupTyVarInfo v = do + r <- lookupTyVar v + case r of + Left info -> pure info + Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v + +setLink :: TyVar -> VName -> SolveM () +setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} + +setInfo :: TyVar -> TyVarSol -> SolveM () +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} + +-- | A solution maps a type variable to its substitution. This +-- substitution is complete, in the sense there are no right-hand +-- sides that contain a type variable. +type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) + +-- | An unconstrained type variable comprises a name and (ironically) +-- a constraint on how it can be instantiated. +type UnconTyVar = (VName, Liftedness) + +typeVar :: (Monoid u) => VName -> TypeBase dim u +typeVar v = Scalar $ TypeVar mempty (qualName v) [] + +solution :: SolverState -> ([UnconTyVar], Solution) +solution s = + ( mapMaybe unconstrained $ M.toList $ solverTyVars s, + M.mapMaybe mkSubst $ solverTyVars s + ) + where + mkSubst (Right (TyVarSol t)) = + Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t + mkSubst (Left v') = + Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ + mkSubst =<< M.lookup v' (solverTyVars s) + mkSubst (Right (TyVarUnsol (TyVarPrim _ pts))) = Just $ Left pts + mkSubst _ = Nothing + + unconstrained (v, Right (TyVarUnsol (TyVarFree _ l))) = Just (v, l) + unconstrained _ = Nothing + +newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} + deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) + +-- Try to substitute as much information as we have. +enrichType :: Type -> SolveM Type +enrichType t = do + s <- get + pure $ substTyVars (substTyVar (solverTyVars s)) t + +typeError :: Loc -> Notes -> Doc () -> SolveM () +typeError loc notes msg = + throwError $ TypeError loc notes msg + +occursCheck :: Reason Type -> VName -> Type -> SolveM () +occursCheck reason v tp = do + vars <- gets solverTyVars + let tp' = substTyVars (substTyVar vars) tp + when (v `S.member` typeVars tp') . typeError (locOf reason) mempty $ + "Occurs check: cannot instantiate" + <+> prettyName v + <+> "with" + <+> pretty tp + <> "." + +unifySharedConstructors :: + Reason Type -> + BreadCrumbs -> + M.Map Name [Type] -> + M.Map Name [Type] -> + SolveM () +unifySharedConstructors reason bcs cs1 cs2 = + forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> + if length ts1 == length ts2 + then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 + else + typeError (locOf reason) mempty $ + "Cannot unify type with constructor" + indent 2 (pretty (Sum (M.singleton c ts1))) + "with type of constructor" + indent 2 (pretty (Sum (M.singleton c ts2))) + "because they differ in arity." + +unifySharedFields :: + Reason Type -> + BreadCrumbs -> + M.Map Name Type -> + M.Map Name Type -> + SolveM () +unifySharedFields reason bcs fs1 fs2 = + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> + solveEq reason (matchingField f <> bcs) ts1 ts2 + +scopeViolation :: Reason Type -> VName -> Type -> VName -> SolveM () +scopeViolation reason v1 ty v2 = + typeError (locOf reason) mempty $ + "Cannot unify type" + indent 2 (pretty ty) + "with" + <+> dquotes (prettyName v1) + <+> "(scope violation)." + "This is because" + <+> dquotes (prettyName v2) + <+> "is rigidly bound in a deeper scope." + +cannotUnify :: + Reason Type -> + Notes -> + BreadCrumbs -> + Type -> + Type -> + SolveM () +cannotUnify reason notes bcs t1 t2 = do + t1' <- enrichType t1 + t2' <- enrichType t2 + case reason of + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty t1'), + "with", + indent 2 (pretty t2') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonPatMatch loc pat value_t -> + typeError loc notes . stack $ + [ "Pattern", + indent 2 $ align $ pretty pat, + "cannot match value of type", + indent 2 $ align $ pretty value_t + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonAscription loc expected actual -> + typeError loc notes . stack $ + [ "Expression does not have expected type from type ascription.", + "Expected:" <+> align (pretty expected), + "Actual: " <+> align (pretty actual) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonRetType loc expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ "Function body does not have expected type.", + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonApply loc f e expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ header, + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + where + header = + case f of + (Nothing, _) -> + "Cannot apply function to" + <+> dquotes (shorten $ group $ pretty e) + <> " (invalid type)." + (Just fname, _) -> + "Cannot apply" + <+> dquotes (pretty fname) + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> " (invalid type)." + ReasonApplySplit loc (fname, 0) _ ftype -> + typeError loc notes $ + stack + [ "Cannot apply" + <+> fname' + <+> "as function, as it has non-function type:" + indent 2 (align $ pretty ftype) + ] + where + fname' = maybe "expression" (dquotes . pretty) fname + ReasonApplySplit loc (fname, i) e _ -> + typeError loc notes $ + stack + [ "Cannot apply" + <+> fname' + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> ".", + "Function accepts only" <+> pretty i <+> "arguments." + ] + where + fname' = maybe "expression" (dquotes . pretty) fname + ReasonBranches loc former latter -> do + former' <- enrichType former + latter' <- enrichType latter + typeError loc notes . stack $ + [ "Branches differ in type.", + "Former:" <+> pretty former', + "Latter:" <+> pretty latter' + ] + +-- Precondition: 'v' is currently flexible. +subTyVar :: Reason Type -> BreadCrumbs -> VName -> Type -> SolveM () +subTyVar reason bcs v t = do + occursCheck reason v t + v_info <- gets $ M.lookup v . solverTyVars + + -- Set a solution for v, then update info for t in case v has any + -- odd constraints. + setInfo v (TyVarSol t) + + case (v_info, t) of + (Just (Right (TyVarUnsol TyVarFree {})), _) -> + pure () + ( Just (Right (TyVarUnsol (TyVarPrim _ v_pts))), + _ + ) -> + if t `elem` map (Scalar . Prim) v_pts + then pure () + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot instance type that must be one of" + indent 2 (pretty v_pts) + "with" + indent 2 (pretty t) + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), + Scalar (Sum cs2) + ) -> + if all (`elem` M.keys cs2) (M.keys cs1) + then unifySharedConstructors reason bcs cs1 cs2 + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot match type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs1))) + "with type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs2))) + unsharedConstructorsMsg cs1 cs2 + ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), + _ + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty t) + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), + Scalar (Record fs2) + ) -> + if all (`elem` M.keys fs2) (M.keys fs1) + then unifySharedFields reason bcs fs1 fs2 + else + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with record type" + indent 2 (pretty (Record fs2)) + ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), + _ + ) -> + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty t) + -- + -- Internal error cases + (Just (Right TyVarSol {}), _) -> + error $ "Type variable already solved: " <> prettyNameString v + (Just (Right TyVarParam {}), _) -> + error $ "Cannot substitute type parameter: " <> prettyNameString v + (Just Left {}, _) -> + error $ "Type variable already linked: " <> prettyNameString v + (Nothing, _) -> + error $ "subTyVar: Nothing v: " <> prettyNameString v + +-- Precondition: 'v' and 't' are both currently flexible. +-- +-- The purpose of this function is to combine the partial knowledge we +-- may have about these two type variables. +unionTyVars :: Reason Type -> BreadCrumbs -> VName -> VName -> SolveM () +unionTyVars reason bcs v t = do + v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars + t_info <- lookupTyVarInfo t + + -- Insert the link from v to t, and then update the info of t based + -- on the existing info of v and t. + setLink v t + + case (v_info, t_info) of + ( TyVarUnsol (TyVarFree _ v_l), + TyVarFree t_loc t_l + ) + | v_l /= t_l -> + setInfo t $ TyVarUnsol $ TyVarFree t_loc (min v_l t_l) + -- When either is completely unconstrained. + (TyVarUnsol TyVarFree {}, _) -> + pure () + ( TyVarUnsol info, + TyVarFree {} + ) -> + setInfo t (TyVarUnsol info) + -- + -- TyVarPrim cases + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarPrim t_loc t_pts + ) -> + let pts = L.intersect v_pts t_pts + in if null pts + then + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be one of" + indent 2 (pretty t_pts) + else setInfo t (TyVarUnsol (TyVarPrim t_loc pts)) + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarRecord {} + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be a record." + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarSum {} + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be sum." + -- + -- TyVarSum cases + ( TyVarUnsol (TyVarSum _ cs1), + TyVarSum loc cs2 + ) -> do + unifySharedConstructors reason bcs cs1 cs2 + let cs3 = cs1 <> cs2 + setInfo t (TyVarUnsol (TyVarSum loc cs3)) + ( TyVarUnsol TyVarSum {}, + TyVarPrim _ pts + ) -> + typeError (locOf reason) mempty $ + "A sum type cannot be one of" + indent 2 (pretty pts) + ( TyVarUnsol (TyVarSum _ cs1), + TyVarRecord _ fs + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Scalar (Record fs))) + -- + -- TyVarRecord cases + ( TyVarUnsol (TyVarRecord _ fs1), + TyVarRecord loc fs2 + ) -> do + unifySharedFields reason bcs fs1 fs2 + let fs3 = fs1 <> fs2 + setInfo t (TyVarUnsol (TyVarRecord loc fs3)) + ( TyVarUnsol TyVarRecord {}, + TyVarPrim _ pts + ) -> + typeError (locOf reason) mempty $ + "A record type cannot be one of" + indent 2 (pretty pts) + ( TyVarUnsol (TyVarRecord _ fs1), + TyVarSum _ cs + ) -> + typeError (locOf reason) mempty $ + "Cannot unify record type" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty (Scalar (Sum cs))) + -- + -- Internal error cases + (TyVarSol {}, _) -> + alreadySolved + (TyVarParam {}, _) -> + isParam + where + unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v + alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v + alreadySolved = error $ "Type variable already solved: " <> prettyNameString v + isParam = error $ "Type name is a type parameter: " <> prettyNameString v + +unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a +unsharedConstructorsMsg cs1 cs2 = + "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." + where + missing = + filter (`notElem` M.keys cs1) (M.keys cs2) + ++ filter (`notElem` M.keys cs2) (M.keys cs1) + +-- Unify at the root, emitting new equalities that must hold. +unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] +unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) + | pt1 == pt2 = Right [] +unify + (Scalar (TypeVar _ (QualName _ v1) targs1)) + (Scalar (TypeVar _ (QualName _ v2) targs2)) + | v1 == v2 = + Right $ mapMaybe f $ zip targs1 targs2 + where + f (TypeArgType t1, TypeArgType t2) = Just (mempty, (t1, t2)) + f _ = Nothing +unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = + Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness +unify (Scalar (Record fs1)) (Scalar (Record fs2)) + | M.keys fs1 == M.keys fs2 = + Right $ + map (first matchingField) $ + M.toList $ + M.intersectionWith (,) fs1 fs2 + | Just n1 <- length <$> areTupleFields fs1, + Just n2 <- length <$> areTupleFields fs2, + n1 /= n2 = + Left $ + "Tuples have" + <+> pretty n1 + <+> "and" + <+> pretty n2 + <+> "elements respectively." + | otherwise = + let missing = + filter (`notElem` M.keys fs1) (M.keys fs2) + <> filter (`notElem` M.keys fs2) (M.keys fs1) + in Left $ + "unshared fields:" <+> commasep (map pretty missing) <> "." +unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) + | M.keys cs1 == M.keys cs2 = + fmap concat . forM cs' $ \(c, (ts1, ts2)) -> do + if length ts1 == length ts2 + then Right $ zipWith (curry (matchingConstructor c,)) ts1 ts2 + else Left mempty + | otherwise = + Left $ unsharedConstructorsMsg cs1 cs2 + where + cs' = M.toList $ M.intersectionWith (,) cs1 cs2 +unify t1 t2 + | Just t1' <- peelArray 1 t1, + Just t2' <- peelArray 1 t2 = + Right [(mempty, (t1', t2'))] +unify _ _ = Left mempty + +solveEq :: Reason Type -> BreadCrumbs -> Type -> Type -> SolveM () +solveEq reason obcs orig_t1 orig_t2 = do + solveCt' (obcs, (orig_t1, orig_t2)) + where + solveCt' (bcs, (t1, t2)) = do + tyvars <- gets solverTyVars + let flexible v = case M.lookup v tyvars of + Just (Left v') -> flexible v' + Just (Right (TyVarUnsol _)) -> True + Just (Right TyVarSol {}) -> False + Just (Right TyVarParam {}) -> False + Nothing -> False + sub t@(Scalar (TypeVar u (QualName [] v) [])) = + case M.lookup v tyvars of + Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) + Just (Right (TyVarSol t')) -> sub t' + _ -> t + sub t = t + case (sub t1, sub t2) of + ( t1'@(Scalar (TypeVar _ (QualName [] v1) [])), + t2'@(Scalar (TypeVar _ (QualName [] v2) [])) + ) + | v1 == v2 -> pure () + | otherwise -> + case (flexible v1, flexible v2) of + (False, False) -> cannotUnify reason mempty bcs t1 t2 + (True, False) -> subTyVar reason bcs v1 t2' + (False, True) -> subTyVar reason bcs v2 t1' + (True, True) -> unionTyVars reason bcs v1 v2 + (Scalar (TypeVar _ (QualName [] v1) []), t2') + | flexible v1 -> subTyVar reason bcs v1 t2' + (t1', Scalar (TypeVar _ (QualName [] v2) [])) + | flexible v2 -> subTyVar reason bcs v2 t1' + (t1', t2') -> case unify t1' t2' of + Left details -> cannotUnify reason (aNote details) bcs t1' t2' + Right eqs -> mapM_ solveCt' eqs + +solveCt :: CtTy () -> SolveM () +solveCt ct = + case ct of + CtEq reason t1 t2 -> solveEq reason mempty t1 t2 + +scopeCheck :: Reason Type -> TyVar -> Int -> Type -> SolveM () +scopeCheck reason v v_lvl ty = do + mapM_ check $ typeVars ty + where + check ty_v = do + ty_v_info <- gets $ M.lookup ty_v . solverTyVars + case ty_v_info of + Just (Right (TyVarParam ty_v_lvl _ _)) + | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + Just (Right (TyVarSol ty')) -> + mapM_ check $ typeVars ty' + _ -> pure () + +-- If a type variable has a liftedness constraint, we propagate that +-- constraint to its solution. The actual checking for correct usage +-- is done later. +liftednessCheck :: Liftedness -> Type -> SolveM () +liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do + v_info <- maybeLookupTyVar v + case v_info of + Nothing -> + -- Is an opaque type. + pure () + Just (TyVarSol v_ty) -> + liftednessCheck l v_ty + Just TyVarParam {} -> pure () + Just (TyVarUnsol (TyVarFree loc v_l)) + | l /= v_l -> + setInfo v $ TyVarUnsol $ TyVarFree loc (min l v_l) + Just TyVarUnsol {} -> pure () +liftednessCheck _ (Scalar Prim {}) = pure () +liftednessCheck Lifted _ = pure () +liftednessCheck _ Array {} = pure () +liftednessCheck _ (Scalar Arrow {}) = pure () +liftednessCheck l (Scalar (Record fs)) = + mapM_ (liftednessCheck l) fs +liftednessCheck l (Scalar (Sum cs)) = + mapM_ (mapM_ $ liftednessCheck l) cs +liftednessCheck _ (Scalar TypeVar {}) = pure () + +solveTyVar :: (VName, (Level, TyVarInfo ())) -> SolveM () +solveTyVar (tv, (_, TyVarRecord loc fs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Left _ -> + typeError loc mempty $ + "Type" + <+> prettyName tv + <+> "is ambiguous." + "Must be a record with fields" + indent 2 (pretty (Scalar (Record fs1))) + Right _ -> + pure () +solveTyVar (tv, (_, TyVarSum loc cs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Left _ -> + typeError loc mempty $ + "Type is ambiguous." + "Must be a sum type with constructors" + indent 2 (pretty (Scalar (Sum cs1))) + Right _ -> pure () +solveTyVar (tv, (lvl, TyVarFree loc l)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right ty -> do + scopeCheck (Reason loc) tv lvl ty + liftednessCheck l ty + _ -> pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right (Scalar (Prim ty)) + | [ty] == pts -> + setInfo tv $ TyVarSol $ Scalar $ Prim ty + Right ty + | ty `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + typeError loc mempty $ + "Numeric constant inferred to be of type" + indent 2 (align (pretty ty)) + "which is not possible." + _ -> pure () + +solve :: + [CtTy ()] -> + TyParams -> + TyVars () -> + Either TypeError ([UnconTyVar], Solution) +solve constraints typarams tyvars = + second solution + . runExcept + . flip execStateT (initialState typarams tyvars) + . runSolveM + $ do + mapM_ solveCt constraints + mapM_ solveTyVar (M.toList tyvars) +{-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 5a5fb42cd8..bde1fbe480 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -36,6 +36,7 @@ import Futhark.Util (topologicalSort) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.Traversals +import Language.Futhark.TypeChecker.Constraints (Level) import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad hiding (BoundV) import Language.Futhark.TypeChecker.Types @@ -60,11 +61,6 @@ instance Pretty Usage where instance Located Usage where locOf (Usage _ loc) = locOf loc --- | The level at which a type variable is bound. Higher means --- deeper. We can only unify a type variable at level @i@ with a type --- @t@ if all type names that occur in @t@ are at most at level @i@. -type Level = Int - -- | A constraint on a yet-ambiguous type variable. data Constraint = NoConstraint Liftedness Usage From ec72dcee7e23abf742addfce8257771365860762 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 28 Jan 2025 22:52:04 +0100 Subject: [PATCH 275/296] Some docs. --- src/Language/Futhark/TypeChecker/TySolve.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs index 058303c43c..8cecb4afd0 100644 --- a/src/Language/Futhark/TypeChecker/TySolve.hs +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -1,6 +1,6 @@ +-- | The constraint solver for type equality constraints. module Language.Futhark.TypeChecker.TySolve ( Type, - TyParams, Solution, UnconTyVar, solve, @@ -24,9 +24,10 @@ import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote) import Language.Futhark.TypeChecker.Types (substTyVars) -- | The type representation used by the constraint solver. Agnostic --- to sizes. +-- to sizes and uniqueness. type Type = CtType () +-- | A (partial) solution for a type variable. data TyVarSol = -- | Has been substituted with this. TyVarSol Type @@ -652,6 +653,8 @@ solveTyVar (tv, (_, TyVarPrim loc pts)) = do "which is not possible." _ -> pure () +-- | Solve type constraints, producing either an error or a solution, +-- alongside a list of unconstrained type variables. solve :: [CtTy ()] -> TyParams -> From e8d78e788238478f33813c71a0626d6259dc251a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 29 Jan 2025 02:07:35 +0100 Subject: [PATCH 276/296] Slightly simplify type checking of patterns. --- .../Futhark/TypeChecker/Constraints.hs | 2 +- src/Language/Futhark/TypeChecker/Terms2.hs | 46 ++++++++----------- tests/ascription0.fut | 8 ++-- tests/issue1783.fut | 2 +- tests/sumtypes/sumtype46.fut | 2 +- tests/sumtypes/sumtype47.fut | 2 +- 6 files changed, 28 insertions(+), 34 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index a4c4971afa..e5860e4369 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -48,7 +48,7 @@ data Reason t = -- | No particular reason. Reason Loc | -- | Arising from pattern match. - ReasonPatMatch Loc (PatBase NoInfo VName ParamType) t + ReasonPatMatch Loc (PatBase NoInfo VName StructType) t | -- | Arising from explicit ascription. ReasonAscription Loc t t | ReasonRetType Loc t t diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 297922f12e..25b9ebe546 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -59,14 +59,6 @@ import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify (mkUsage) import Prelude hiding (mod) -data Inferred t - = NoneInferred - | Ascribed t - -instance Functor Inferred where - fmap _ NoneInferred = NoneInferred - fmap f (Ascribed t) = Ascribed (f t) - type Type = CtType SComp -- | Careful when using this on something that already has an SComp @@ -482,6 +474,10 @@ checkSizeExp' e = do ctEq (Reason (locOf e)) e_t (Scalar (Prim (Signed Int64))) pure e' +data Inferred t + = NoneInferred + | Ascribed t + checkPat' :: PatBase NoInfo VName ParamType -> Inferred (TypeBase SComp Diet) -> @@ -502,19 +498,18 @@ checkPat' (Wildcard _ loc) (Ascribed t) = do checkPat' (Wildcard NoInfo loc) NoneInferred = do t <- newType loc Lifted "t" Observe pure $ Wildcard (Info t) loc -checkPat' (TuplePat ps loc) (Ascribed t) +checkPat' p@(TuplePat ps loc) (Ascribed t) | Just ts <- isTupleRecord t, length ts == length ps = TuplePat <$> zipWithM checkPat' ps (map Ascribed ts) <*> pure loc - | otherwise = do - ps_tvs <- replicateM (length ps) (newTyVar loc Lifted "t") - ctEq - (ReasonPatMatch (locOf loc) (TuplePat ps loc) (toStruct t)) - (Scalar (tupleRecord $ map (tyVarType NoUniqueness) ps_tvs)) - t - TuplePat <$> zipWithM checkPat' ps (map (Ascribed . tyVarType Observe) ps_tvs) <*> pure loc + | otherwise = + typeError loc mempty $ + "Pattern" + indent 2 (pretty p) + "cannot match ascripted type" + pretty t checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) _ @@ -601,18 +596,16 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps t <- newTypeWithConstr loc "t" Observe n =<< mapM (asType . patternType) ps' - t' <- asStructType t - pure $ PatConstr n (Info $ toParam Observe t') ps' loc + pure $ PatConstr n (Info $ toParam Observe t) ps' loc checkPat :: PatBase NoInfo VName (TypeBase Size u) -> - Inferred Type -> (Pat ParamType -> TermM a) -> TermM a -checkPat p t m = - m =<< checkPat' (fmap (toParam Observe) p) (fmap (fmap (const Observe)) t) +checkPat p m = + m =<< checkPat' (fmap (toParam Observe) p) NoneInferred --- | Bind @let@-bound sizes. This is usually followed by 'bindletPat' +-- | Bind @let@-bound sizes. This is usually followed by 'bindLetPat' -- immediately afterwards. bindSizes :: [SizeBinder VName] -> TermM a -> TermM a bindSizes [] m = m -- Minor optimisation. @@ -627,9 +620,10 @@ bindLetPat :: (Pat ParamType -> TermM a) -> TermM a bindLetPat p t m = do - checkPat p (Ascribed t) $ \p' -> - bind (patIdents (fmap toStruct p')) $ - m p' + checkPat p $ \p' -> do + pt <- asType $ patternType p' + ctEq (ReasonPatMatch (locOf p) (fmap toStruct p) t) pt t + bind (patIdents (fmap toStruct p')) $ m p' typeParamIdent :: TypeParam -> Maybe (Ident StructType) typeParamIdent (TypeParamDim v loc) = @@ -674,7 +668,7 @@ bindParams :: TermM a bindParams tps orig_ps m = bindTypeParams tps $ do let descend ps' (p : ps) = - checkPat p NoneInferred $ \p' -> + checkPat p $ \p' -> bind (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps descend ps' [] = m $ reverse ps' diff --git a/tests/ascription0.fut b/tests/ascription0.fut index 8c3a50e026..92317ea521 100644 --- a/tests/ascription0.fut +++ b/tests/ascription0.fut @@ -1,8 +1,8 @@ -- Make sure type errors due to invalid type ascriptions are caught. -- -- == --- error: match +-- error: ascription -def main(x: i32, y:i32): (bool,bool) = - let (((a): i32), b: i32) : (bool,bool) = (x,y) - in (a,b) +def main (x: i32, y: i32) : (bool, bool) = + let (((a): i32), b: i32): (bool, bool) = (x, y) + in (a, b) diff --git a/tests/issue1783.fut b/tests/issue1783.fut index 0b8e197910..41408ccb7f 100644 --- a/tests/issue1783.fut +++ b/tests/issue1783.fut @@ -1,5 +1,5 @@ -- == --- error: constructor arguments +-- error: differ in arity type surface = #asphere {curvature: f64} diff --git a/tests/sumtypes/sumtype46.fut b/tests/sumtypes/sumtype46.fut index a56af88750..8abc3278fb 100644 --- a/tests/sumtypes/sumtype46.fut +++ b/tests/sumtypes/sumtype46.fut @@ -1,5 +1,5 @@ -- == --- error: 0 constructor arguments +-- error: differ in arity type t = #foo f64 diff --git a/tests/sumtypes/sumtype47.fut b/tests/sumtypes/sumtype47.fut index 09a51f3faa..c0039bfc43 100644 --- a/tests/sumtypes/sumtype47.fut +++ b/tests/sumtypes/sumtype47.fut @@ -1,5 +1,5 @@ -- == --- error: 2 constructor arguments +-- error: differ in arity type t = #foo f64 From 7d06fbdfaf1ff8d9c369783d0ff2731266afe61b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 29 Jan 2025 02:18:55 +0100 Subject: [PATCH 277/296] Further simplify. --- src/Language/Futhark/TypeChecker/Terms2.hs | 39 ++++++++-------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 25b9ebe546..a00fc31616 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -509,7 +509,7 @@ checkPat' p@(TuplePat ps loc) (Ascribed t) "Pattern" indent 2 (pretty p) "cannot match ascripted type" - pretty t + indent 2 (pretty t) checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc checkPat' p@(RecordPat p_fs loc) _ @@ -529,12 +529,11 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) map fst t_fs' == map (unLoc . fst) p_fs' = RecordPat <$> zipWithM check p_fs' t_fs' <*> pure loc | otherwise = do - p_fs' <- - traverse (const $ newType loc Lifted "t" NoUniqueness) $ - M.fromList $ - map (first unLoc) p_fs - ctEq (Reason (locOf loc)) (Scalar (Record p_fs')) t - checkPat' p $ Ascribed $ Observe <$ Scalar (Record p_fs') + typeError loc mempty $ + "Pattern" + indent 2 (pretty p) + "cannot match ascripted type" + indent 2 (pretty t) where check (L f_loc f, p_f) (_, t_f) = (L f_loc f,) <$> checkPat' p_f (Ascribed t_f) @@ -572,27 +571,17 @@ checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc pure $ PatLit l (Info t') loc checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) - | Just ts <- M.lookup n cs = do - when (length ps /= length ts) $ - typeError loc mempty $ - "Pattern #" - <> pretty n - <> " expects" - <+> pretty (length ps) - <+> "constructor arguments, but type provides" - <+> pretty (length ts) - <+> "arguments." + | Just ts <- M.lookup n cs, + length ps == length ts = do ps' <- zipWithM checkPat' ps $ map Ascribed ts cs' <- traverse (mapM asStructType) cs pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc -checkPat' (PatConstr n NoInfo ps loc) (Ascribed t) = do - ps' <- forM ps $ \p -> do - p_t <- newType (srclocOf p) Lifted "t" Observe - checkPat' p $ Ascribed p_t - t' <- newTypeWithConstr loc "t" Observe n =<< mapM (asType . patternType) ps' - ctEq (Reason (locOf loc)) t' t - t'' <- asStructType t' - pure $ PatConstr n (Info $ toParam Observe t'') ps' loc +checkPat' p@(PatConstr {}) (Ascribed t) = + typeError (locOf p) mempty $ + "Pattern" + indent 2 (pretty p) + "cannot match ascripted type" + indent 2 (pretty t) checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps t <- newTypeWithConstr loc "t" Observe n =<< mapM (asType . patternType) ps' From ddcb6cec2c6e07ab52644073b6bcebebd5653e0e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 29 Jan 2025 10:40:56 +0100 Subject: [PATCH 278/296] Simplify. --- src/Language/Futhark/TypeChecker/Rank.hs | 41 ++++++++++++------------ 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index c471395422..60ae6f0a07 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -127,14 +127,14 @@ binVar sv = do modify $ \s -> s { rankBinVars = M.insert sv bv $ rankBinVars s, - rankConstraints = rankConstraints s ++ [bin bv, var bv ~<=~ var sv] + rankConstraints = [bin bv, var bv ~<=~ var sv] <> rankConstraints s } pure bv Just bv -> pure bv addConstraints :: [Constraint] -> RankM () addConstraints cs = - modify $ \s -> s {rankConstraints = rankConstraints s ++ cs} + modify $ \s -> s {rankConstraints = cs <> rankConstraints s} addConstraint :: Constraint -> RankM () addConstraint = addConstraints . pure @@ -194,14 +194,14 @@ ambigCheckLinearProg :: LinearProg -> (Int, Map VName Int) -> LinearProg ambigCheckLinearProg prog (opt, ranks) = prog { constraints = - constraints prog - -- https://yetanothermathprogrammingconsultant.blogspot.com/2011/10/integer-cuts.html - ++ [ lsum (var <$> M.keys one_bins) - ~-~ lsum (var <$> M.keys zero_bins) - ~<=~ constant (fromIntegral $ length one_bins) - ~-~ constant 1, - objective prog ~==~ constant (fromIntegral opt) - ] + -- https://yetanothermathprogrammingconsultant.blogspot.com/2011/10/integer-cuts.html + [ lsum (var <$> M.keys one_bins) + ~-~ lsum (var <$> M.keys zero_bins) + ~<=~ constant (fromIntegral $ length one_bins) + ~-~ constant 1, + objective prog ~==~ constant (fromIntegral opt) + ] + ++ constraints prog } where -- We really need to track which variables are binary in the LinearProg @@ -376,6 +376,11 @@ data SubstState = SubstState instance MonadTrans SubstT where lift = SubstT . lift . lift +rankToShape :: (Monad m) => VName -> SubstT m (Shape ()) +rankToShape x = do + rs <- asks envRanks + pure $ Shape $ replicate (fromJust $ rs M.!? x) () + newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar newTyVar t = do t' <- lift $ newTypeName (baseName t) @@ -385,20 +390,14 @@ newTyVar t = do s { substNewVars = M.insert t t' $ substNewVars s, substNewCts = - substNewCts s - ++ [ CtEq - (Reason loc) - (Scalar (TypeVar mempty (QualName [] t) [])) - (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) - ] + CtEq + (Reason loc) + (Scalar (TypeVar mempty (QualName [] t) [])) + (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) + : substNewCts s } pure t' -rankToShape :: (Monad m) => VName -> SubstT m (Shape ()) -rankToShape x = do - rs <- asks envRanks - pure $ Shape $ replicate (fromJust $ rs M.!? x) () - addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () addRankInfo t = do rs <- asks envRanks From d0552ec9c5abebfa8f8d0364583bb05e69f64444 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 29 Jan 2025 10:41:10 +0100 Subject: [PATCH 279/296] Remove dead code. --- src/Language/Futhark/TypeChecker/Rank.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 60ae6f0a07..da2071a116 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -172,8 +172,6 @@ mkLinearProg cs cs_am tyVars = LP.LinearProg { optType = Minimize, objective = rankObj finalState, - -- let shape_vars = M.keys $ rankBinVars finalState - -- in foldr (\sv s -> var sv ~+~ s) (constant 0) shape_vars, constraints = rankConstraints finalState } where From a89b63e5f09e10fc1498bd76fa4f0cfe5d9d225f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 29 Jan 2025 10:46:52 +0100 Subject: [PATCH 280/296] More cleanup. --- src/Language/Futhark/TypeChecker/Constraints.hs | 3 +-- src/Language/Futhark/TypeChecker/Rank.hs | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index e5860e4369..9a75349517 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -1,5 +1,4 @@ --- | Constraint solver for solving type equations produced --- post-AUTOMAP. +-- | Constraints produced (and solved) by the type checker. module Language.Futhark.TypeChecker.Constraints ( Reason (..), SVar, diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index da2071a116..b82de905f5 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -1,7 +1,6 @@ module Language.Futhark.TypeChecker.Rank ( rankAnalysis, rankAnalysis1, - CtAM (..), ) where From bad8db7056d1127779f6a1d984280bea7209d83e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 29 Jan 2025 12:01:37 +0100 Subject: [PATCH 281/296] Further simplify. --- src/Language/Futhark/TypeChecker/Terms2.hs | 43 +++++++++++----------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index a00fc31616..134c4aae29 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -474,27 +474,29 @@ checkSizeExp' e = do ctEq (Reason (locOf e)) e_t (Scalar (Prim (Signed Int64))) pure e' -data Inferred t +-- | During type checking a pattern, we might find an explicit +-- ascription. These contain complete type information (although they +-- must of course still be checked against what remains of the +-- pattern). +data Inferred = NoneInferred - | Ascribed t + | Ascribed ParamType checkPat' :: PatBase NoInfo VName ParamType -> - Inferred (TypeBase SComp Diet) -> + Inferred -> TermM (Pat ParamType) checkPat' (PatParens p loc) t = PatParens <$> checkPat' p t <*> pure loc checkPat' (PatAttr attr p loc) t = PatAttr <$> checkAttr attr <*> checkPat' p t <*> pure loc -checkPat' (Id name NoInfo loc) (Ascribed t) = do - t' <- asStructType t - pure $ Id name (Info t') loc +checkPat' (Id name NoInfo loc) (Ascribed t) = + pure $ Id name (Info t) loc checkPat' (Id name NoInfo loc) NoneInferred = do t <- newType loc Lifted "t" Observe pure $ Id name (Info t) loc checkPat' (Wildcard _ loc) (Ascribed t) = do - t' <- asStructType t - pure $ Wildcard (Info t') loc + pure $ Wildcard (Info t) loc checkPat' (Wildcard NoInfo loc) NoneInferred = do t <- newType loc Lifted "t" Observe pure $ Wildcard (Info t) loc @@ -508,7 +510,7 @@ checkPat' p@(TuplePat ps loc) (Ascribed t) typeError loc mempty $ "Pattern" indent 2 (pretty p) - "cannot match ascripted type" + "cannot match ascribed type" indent 2 (pretty t) checkPat' (TuplePat ps loc) NoneInferred = TuplePat <$> mapM (`checkPat'` NoneInferred) ps <*> pure loc @@ -532,7 +534,7 @@ checkPat' p@(RecordPat p_fs loc) (Ascribed t) typeError loc mempty $ "Pattern" indent 2 (pretty p) - "cannot match ascripted type" + "cannot match ascribed type" indent 2 (pretty t) where check (L f_loc f, p_f) (_, t_f) = @@ -544,16 +546,16 @@ checkPat' (RecordPat fs loc) NoneInferred = checkPat' (PatAscription p t loc) maybe_outer_t = do (t', _, RetType _ st, _) <- checkTypeExp checkSizeExp' t - -- Uniqueness kung fu to make the Monoid(mempty) instance give what - -- we expect. We should perhaps stop being so implicit. - st' <- asType $ resToParam st + let st' = resToParam st case maybe_outer_t of Ascribed outer_t -> do - ctEq - (ReasonAscription (locOf loc) (toStruct st') (toStruct outer_t)) - st' - outer_t + unless (toType st' == toType outer_t) $ + typeError loc mempty $ + "Ascribed type" + indent 2 (pretty st) + "cannot match outer ascribed type" + indent 2 (pretty outer_t) PatAscription <$> checkPat' p (Ascribed st') <*> pure t' @@ -565,7 +567,7 @@ checkPat' (PatAscription p t loc) maybe_outer_t = do <*> pure loc checkPat' (PatLit l NoInfo loc) (Ascribed t) = do t' <- patLitMkType l loc - ctEq (Reason (locOf loc)) (toType t') t + ctEq (Reason (locOf loc)) (toType t') (toType t) pure $ PatLit l (Info t') loc checkPat' (PatLit l NoInfo loc) NoneInferred = do t' <- patLitMkType l loc @@ -574,13 +576,12 @@ checkPat' (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) | Just ts <- M.lookup n cs, length ps == length ts = do ps' <- zipWithM checkPat' ps $ map Ascribed ts - cs' <- traverse (mapM asStructType) cs - pure $ PatConstr n (Info (Scalar (Sum cs'))) ps' loc + pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' p@(PatConstr {}) (Ascribed t) = typeError (locOf p) mempty $ "Pattern" indent 2 (pretty p) - "cannot match ascripted type" + "cannot match ascribed type" indent 2 (pretty t) checkPat' (PatConstr n NoInfo ps loc) NoneInferred = do ps' <- mapM (`checkPat'` NoneInferred) ps From 840f1a1ce03081d89b64ecadeead6d6b4828f3de Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 29 Jan 2025 12:11:37 +0100 Subject: [PATCH 282/296] Improve nomenclature. --- futhark.cabal | 2 +- src/Language/Futhark/TypeChecker/Terms.hs | 23 +++++++++++--- .../{Terms2.hs => Terms/Unsized.hs} | 30 ++++++------------- src/Language/Futhark/TypeChecker/TySolve.hs | 2 +- 4 files changed, 30 insertions(+), 27 deletions(-) rename src/Language/Futhark/TypeChecker/{Terms2.hs => Terms/Unsized.hs} (97%) diff --git a/futhark.cabal b/futhark.cabal index a2e0ef6dcf..a3e8a85e31 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -433,10 +433,10 @@ library Language.Futhark.TypeChecker.Monad Language.Futhark.TypeChecker.Rank Language.Futhark.TypeChecker.Terms - Language.Futhark.TypeChecker.Terms2 Language.Futhark.TypeChecker.Terms.Loop Language.Futhark.TypeChecker.Terms.Monad Language.Futhark.TypeChecker.Terms.Pat + Language.Futhark.TypeChecker.Terms.Unsized Language.Futhark.TypeChecker.Types Language.Futhark.TypeChecker.Unify Language.Futhark.Warnings diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index b6009887ed..797dcb1b01 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -6,6 +6,21 @@ -- number of built-in language constructs, as well as uniqueness -- types. This is mostly done in an ad hoc way, and many programs -- will require the programmer to fall back on type annotations. +-- +-- The strategy is to split type checking into sveral (main) passes: +-- +-- 1) A size-agnostic pass implemented in +-- "Language.Futhark.TypeChecker.Terms.Unsized". +-- +-- 2) Pass (1) has given us a program where we know the types of +-- everything, but the sizes of nothing. Pass (2) then does +-- essentially size inference, with the benefit of already knowing the +-- full unsized type of everything. This is done using a syntax-driven +-- approach, similar to Algorithm W. +-- +-- 3) The program is then checked for violation of uniqueness +-- properties, which is implemented in +-- "Language.Futhark.TypeChecker.Consumption". module Language.Futhark.TypeChecker.Terms ( checkOneExp, checkSizeExp, @@ -37,7 +52,7 @@ import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod) import Language.Futhark.TypeChecker.Terms.Loop import Language.Futhark.TypeChecker.Terms.Monad import Language.Futhark.TypeChecker.Terms.Pat -import Language.Futhark.TypeChecker.Terms2 qualified as Terms2 +import Language.Futhark.TypeChecker.Terms.Unsized qualified as Unsized import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) @@ -965,7 +980,7 @@ checkApply _ _ _ _ _ = -- parameters will be non-empty. checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp) checkOneExp e = do - (maybe_tysubsts, e') <- Terms2.checkSingleExp e + (maybe_tysubsts, e') <- Unsized.checkSingleExp e case maybe_tysubsts of Left err -> throwError err Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do @@ -983,7 +998,7 @@ checkOneExp e = do -- turn out to be polymorphic, in which case it is unified with i64. checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp checkSizeExp e = do - (maybe_tysubsts, e') <- Terms2.checkSizeExp e + (maybe_tysubsts, e') <- Unsized.checkSizeExp e case maybe_tysubsts of Left err -> throwError err Right (_generalised, tysubsts) -> runTermTypeM checkExp tysubsts $ do @@ -1578,7 +1593,7 @@ checkFunDef :: Exp ) checkFunDef (fname, retdecl, tparams, params, body, loc) = - doChecks =<< Terms2.checkValDef (fname, retdecl, tparams, params, body, loc) + doChecks =<< Unsized.checkValDef (fname, retdecl, tparams, params, body, loc) where -- TODO: Print out the possibilities. (And also potentially eliminate --- some of the possibilities to disambiguate). diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms/Unsized.hs similarity index 97% rename from src/Language/Futhark/TypeChecker/Terms2.hs rename to src/Language/Futhark/TypeChecker/Terms/Unsized.hs index 134c4aae29..8c5f7a80cb 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Unsized.hs @@ -1,27 +1,15 @@ --- | A very WIP reimplementation of type checking of terms. +-- | Unsized type checking. -- --- The strategy is to split type checking into two (main) passes: --- --- 1) A size-agnostic pass that generates type constraints (type Ct) --- and AUTOMAP constraints (type CtAM) which are later solved offline --- to find a solution. This produces an AST where most of the type --- annotations are just references to type variables. Further, all the +-- This checker generates type constraints (type 'CtTy') and AUTOMAP +-- constraints (type 'CtAM') which are then solved to find a solution. +-- The result is a decorated AST where most of the type annotations +-- are just references to type variables. Further, all the -- size-specific annotations (e.g. existential sizes) just contain --- dummy values, such as empty lists. The constraints use a type --- representation where all dimensions are the same. However, we do --- try to store the sizes resulting from explicit type ascriptions - --- these cannot refer to inferred existentials, so it is safe to --- resolve them here. We don't do anything with this information, --- however. +-- dummy values, such as empty lists. -- --- 2) Pass (1) has given us a program where we know the types of --- everything, but the sizes of nothing. Pass (2) then does --- essentially size inference, much like the current/old type checker, --- but of course with the massive benefit of already knowing the full --- type of everything. This can be implemented using online constraint --- solving (as before), or perhaps a completely syntax-driven --- approach. -module Language.Futhark.TypeChecker.Terms2 +-- If Futhark had no fancy type system features, then this pass would +-- essentially be all you needed. +module Language.Futhark.TypeChecker.Terms.Unsized ( checkValDef, checkSingleExp, checkSizeExp, diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs index 8cecb4afd0..44b873f925 100644 --- a/src/Language/Futhark/TypeChecker/TySolve.hs +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -1,4 +1,4 @@ --- | The constraint solver for type equality constraints. +-- | The constraint solver for unsized type equality constraints. module Language.Futhark.TypeChecker.TySolve ( Type, Solution, From ddfd279acd237f3bf1b8610bbda9445cf1be5f25 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 6 Mar 2025 09:57:14 +0100 Subject: [PATCH 283/296] Start work on unit testing constraint solver. --- futhark.cabal | 1 + unittests/Language/Futhark/SyntaxTests.hs | 94 +++++++++++-------- .../Futhark/TypeChecker/TySolveTests.hs | 46 +++++++++ .../Language/Futhark/TypeCheckerTests.hs | 4 +- 4 files changed, 106 insertions(+), 39 deletions(-) create mode 100644 unittests/Language/Futhark/TypeChecker/TySolveTests.hs diff --git a/futhark.cabal b/futhark.cabal index a3e8a85e31..47b096c321 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -551,6 +551,7 @@ test-suite unit Language.Futhark.SemanticTests Language.Futhark.SyntaxTests Language.Futhark.TypeChecker.TypesTests + Language.Futhark.TypeChecker.TySolveTests Language.Futhark.TypeCheckerTests build-depends: QuickCheck >=2.8 diff --git a/unittests/Language/Futhark/SyntaxTests.hs b/unittests/Language/Futhark/SyntaxTests.hs index 4f5c80b2e8..61efc6306c 100644 --- a/unittests/Language/Futhark/SyntaxTests.hs +++ b/unittests/Language/Futhark/SyntaxTests.hs @@ -112,78 +112,84 @@ pPrimType = pUniqueness :: Parser Uniqueness pUniqueness = choice [lexeme "*" $> Unique, pure Nonunique] +pDim :: Parser d -> Parser d +pDim = brackets + pSize :: Parser Size pSize = - brackets $ - choice - [ flip sizeFromInteger mempty <$> lexeme L.decimal, - flip sizeFromName mempty <$> pQualName - ] + choice + [ flip sizeFromInteger mempty <$> lexeme L.decimal, + flip sizeFromName mempty <$> pQualName + ] -pScalarNonFun :: Parser (ScalarTypeBase Size Uniqueness) -pScalarNonFun = +pScalarNonFun :: Parser d -> Parser (ScalarTypeBase d Uniqueness) +pScalarNonFun pd = choice [ Prim <$> pPrimType, pTypeVar, - tupleRecord <$> parens (pType `sepBy` lexeme ","), + tupleRecord <$> parens (pType pd `sepBy` lexeme ","), Record . M.fromList <$> braces (pField `sepBy1` lexeme ",") ] where - pField = (,) <$> pName <* lexeme ":" <*> pType + pField = (,) <$> pName <* lexeme ":" <*> pType pd pTypeVar = TypeVar <$> pUniqueness <*> pQualName <*> many pTypeArg pTypeArg = choice - [ TypeArgDim <$> pSize, + [ TypeArgDim <$> pDim pd, TypeArgType . second (const NoUniqueness) <$> pTypeArgType ] pTypeArgType = choice [ Scalar . Prim <$> pPrimType, - parens pType + parens $ pType pd ] -pArrayType :: Parser ResType -pArrayType = +pArrayType :: Parser d -> Parser (TypeBase d Uniqueness) +pArrayType pd = Array <$> pUniqueness - <*> (Shape <$> some pSize) - <*> (second (const NoUniqueness) <$> pScalarNonFun) + <*> (Shape <$> some (pDim pd)) + <*> (second (const NoUniqueness) <$> pScalarNonFun pd) -pNonFunType :: Parser ResType -pNonFunType = +pNonFunType :: Parser d -> Parser (TypeBase d Uniqueness) +pNonFunType pd = choice - [ try pArrayType, - try $ parens pType, - Scalar <$> pScalarNonFun + [ try $ pArrayType pd, + try $ parens $ pType pd, + Scalar <$> pScalarNonFun pd ] -pScalarType :: Parser (ScalarTypeBase Size Uniqueness) -pScalarType = choice [try pFun, pScalarNonFun] +uniquenessToDiet :: Uniqueness -> Diet +uniquenessToDiet Unique = Consume +uniquenessToDiet Nonunique = Observe + +pScalarType :: Parser d -> Parser (ScalarTypeBase d Uniqueness) +pScalarType pd = choice [try pFun, pScalarNonFun pd] where pFun = - pParam <* lexeme "->" <*> pRetType + pParam <* lexeme "->" <*> pRetType pd pParam = choice [ try pNamedParam, do - t <- pNonFunType - pure $ Arrow Nonunique Unnamed (diet $ resToParam t) (toStruct t) + t <- pNonFunType pd + pure $ Arrow Nonunique Unnamed (diet $ second uniquenessToDiet t) (toStruct t) ] pNamedParam = parens $ do v <- pVName <* lexeme ":" - t <- pType - pure $ Arrow Nonunique (Named v) (diet $ resToParam t) (toStruct t) + t <- pType pd + pure $ Arrow Nonunique (Named v) (diet $ second uniquenessToDiet t) (toStruct t) -pRetType :: Parser ResRetType -pRetType = +pRetType :: Parser d -> Parser (RetTypeBase d Uniqueness) +pRetType pd = choice - [ lexeme "?" *> (RetType <$> some (brackets pVName) <* lexeme "." <*> pType), - RetType [] <$> pType + [ lexeme "?" *> (RetType <$> some (brackets pVName) <* lexeme "." <*> pType pd), + RetType [] <$> pType pd ] -pType :: Parser ResType -pType = - choice [try $ Scalar <$> pScalarType, pArrayType, parens pType] +pType :: Parser d -> Parser (TypeBase d Uniqueness) +pType pd = + choice [try $ Scalar <$> pScalarType pd, pArrayType pd, parens (pType pd)] fromStringParse :: Parser a -> String -> String -> a fromStringParse p what s = @@ -194,15 +200,27 @@ fromStringParse p what s = instance IsString (ScalarTypeBase Size NoUniqueness) where fromString = - fromStringParse (second (const NoUniqueness) <$> pScalarType) "ScalarType" + fromStringParse + (second (const NoUniqueness) <$> pScalarType pSize) + "ScalarType" + +instance IsString (ScalarTypeBase () NoUniqueness) where + fromString = + fromStringParse + (second (const NoUniqueness) <$> pScalarType (pure ())) + "ScalarType" + +instance IsString (TypeBase () NoUniqueness) where + fromString = + fromStringParse (second (const NoUniqueness) <$> pType (pure ())) "Type" instance IsString StructType where fromString = - fromStringParse (second (const NoUniqueness) <$> pType) "StructType" + fromStringParse (second (const NoUniqueness) <$> pType pSize) "StructType" instance IsString StructRetType where fromString = - fromStringParse (second (pure NoUniqueness) <$> pRetType) "StructRetType" + fromStringParse (second (pure NoUniqueness) <$> pRetType pSize) "StructRetType" instance IsString ResRetType where - fromString = fromStringParse pRetType "ResRetType" + fromString = fromStringParse (pRetType pSize) "ResRetType" diff --git a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs new file mode 100644 index 0000000000..327b6b0fb0 --- /dev/null +++ b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs @@ -0,0 +1,46 @@ +module Language.Futhark.TypeChecker.TySolveTests (tests) where + +import Data.Map qualified as M +import Futhark.Util.Pretty (docString) +import Language.Futhark.Syntax (Liftedness (..)) +import Language.Futhark.SyntaxTests () +import Language.Futhark.TypeChecker.Constraints + ( CtTy (..), + Reason (..), + TyParams, + TyVarInfo (..), + TyVars, + ) +import Language.Futhark.TypeChecker.Monad (prettyTypeError) +import Language.Futhark.TypeChecker.TySolve +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.HUnit (Assertion, assertFailure, testCase, (@?=)) + +testSolve :: + [CtTy ()] -> + TyParams -> + TyVars () -> + ([UnconTyVar], Solution) -> + Assertion +testSolve constraints typarams tyvars expected = + case solve constraints typarams tyvars of + Right s -> s @?= expected + Left e -> assertFailure $ docString $ prettyTypeError e + +-- When writing type variables/names here (a_0, b_1), make *sure* that +-- the numbers are distinct. These are all that actually matter for +-- determining identify. + +tests :: TestTree +tests = + testGroup + "Unsized constraint solver" + [ testCase "empty" $ + testSolve [] mempty mempty ([], mempty), + testCase "a_0 ~ b_1" $ + testSolve + [CtEq (Reason mempty) "a_0" "b_1"] + mempty + (M.fromList [("a_0", (0, TyVarFree mempty Unlifted))]) + ([], M.fromList [("a_0", Right "b_1")]) + ] diff --git a/unittests/Language/Futhark/TypeCheckerTests.hs b/unittests/Language/Futhark/TypeCheckerTests.hs index 02619fcec7..f83534c060 100644 --- a/unittests/Language/Futhark/TypeCheckerTests.hs +++ b/unittests/Language/Futhark/TypeCheckerTests.hs @@ -1,5 +1,6 @@ module Language.Futhark.TypeCheckerTests (tests) where +import Language.Futhark.TypeChecker.TySolveTests qualified import Language.Futhark.TypeChecker.TypesTests qualified import Test.Tasty @@ -7,5 +8,6 @@ tests :: TestTree tests = testGroup "Source type checker tests" - [ Language.Futhark.TypeChecker.TypesTests.tests + [ Language.Futhark.TypeChecker.TypesTests.tests, + Language.Futhark.TypeChecker.TySolveTests.tests ] From cf552f943f93175f0c7b47bfa1f99ebe6972a6f9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 6 Mar 2025 10:46:40 +0100 Subject: [PATCH 284/296] Slightly nicer. --- .../Language/Futhark/TypeChecker/TySolveTests.hs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs index 327b6b0fb0..426310f82d 100644 --- a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs +++ b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs @@ -2,10 +2,11 @@ module Language.Futhark.TypeChecker.TySolveTests (tests) where import Data.Map qualified as M import Futhark.Util.Pretty (docString) -import Language.Futhark.Syntax (Liftedness (..)) +import Language.Futhark.Syntax (Liftedness (..), NoUniqueness, TypeBase, VName) import Language.Futhark.SyntaxTests () import Language.Futhark.TypeChecker.Constraints ( CtTy (..), + Level, Reason (..), TyParams, TyVarInfo (..), @@ -29,7 +30,13 @@ testSolve constraints typarams tyvars expected = -- When writing type variables/names here (a_0, b_1), make *sure* that -- the numbers are distinct. These are all that actually matter for --- determining identify. +-- determining identity. + +(~) :: TypeBase () NoUniqueness -> TypeBase () NoUniqueness -> CtTy () +t1 ~ t2 = CtEq (Reason mempty) t1 t2 + +tv :: VName -> Level -> (VName, (Level, TyVarInfo ())) +tv v lvl = (v, (lvl, TyVarFree mempty Unlifted)) tests :: TestTree tests = @@ -39,8 +46,8 @@ tests = testSolve [] mempty mempty ([], mempty), testCase "a_0 ~ b_1" $ testSolve - [CtEq (Reason mempty) "a_0" "b_1"] + ["a_0" ~ "b_1"] mempty - (M.fromList [("a_0", (0, TyVarFree mempty Unlifted))]) + (M.fromList [tv "a_0" 0]) ([], M.fromList [("a_0", Right "b_1")]) ] From 4604a7d1b8b3b06ee041af6d57139fc22739903c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 7 Apr 2025 18:51:37 +0200 Subject: [PATCH 285/296] Add TySolve logging machinery. --- src/Language/Futhark/TypeChecker/TySolve.hs | 26 +++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs index 44b873f925..41077d5137 100644 --- a/src/Language/Futhark/TypeChecker/TySolve.hs +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -16,11 +16,13 @@ import Data.Loc import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S +import Debug.Trace +import Futhark.Util (isEnvVarAtLeast) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Error -import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote) +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote, prettyTypeError) import Language.Futhark.TypeChecker.Types (substTyVars) -- | The type representation used by the constraint solver. Agnostic @@ -661,11 +663,31 @@ solve :: TyVars () -> Either TypeError ([UnconTyVar], Solution) solve constraints typarams tyvars = - second solution + logProblem + . second solution . runExcept . flip execStateT (initialState typarams tyvars) . runSolveM $ do mapM_ solveCt constraints mapM_ solveTyVar (M.toList tyvars) + where + logProblem + | isEnvVarAtLeast "FUTHARK_LOG_TYSOLVE" 0 = \s -> + let msg = + unlines + [ "# TySolve.solve", + "## constraints", + show constraints, + "## typarams", + show typarams, + "## tyvars", + show tyvars, + either + (("## error\n" <>) . docString . prettyTypeError) + (("## solution\n" <>) . show) + s + ] + in trace msg s + | otherwise = id {-# NOINLINE solve #-} From 0816c0b77a9d00e2408bb6d91e853ae8a8d69604 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 7 Apr 2025 19:18:29 +0200 Subject: [PATCH 286/296] Add another test case. --- unittests/Language/Futhark/SyntaxTests.hs | 2 +- .../Futhark/TypeChecker/TySolveTests.hs | 24 +++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/unittests/Language/Futhark/SyntaxTests.hs b/unittests/Language/Futhark/SyntaxTests.hs index 61efc6306c..fdbb760791 100644 --- a/unittests/Language/Futhark/SyntaxTests.hs +++ b/unittests/Language/Futhark/SyntaxTests.hs @@ -51,7 +51,7 @@ instance Arbitrary PrimValue where instance IsString VName where fromString s = - let (s', '_' : tag) = span (/= '_') s + let (tag, s') = bimap reverse (reverse . tail) $ span (/= '_') $ reverse s in VName (fromString s') (read tag) instance (IsString v) => IsString (QualName v) where diff --git a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs index 426310f82d..4d1dafc231 100644 --- a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs +++ b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs @@ -2,7 +2,7 @@ module Language.Futhark.TypeChecker.TySolveTests (tests) where import Data.Map qualified as M import Futhark.Util.Pretty (docString) -import Language.Futhark.Syntax (Liftedness (..), NoUniqueness, TypeBase, VName) +import Language.Futhark.Syntax import Language.Futhark.SyntaxTests () import Language.Futhark.TypeChecker.Constraints ( CtTy (..), @@ -44,10 +44,30 @@ tests = "Unsized constraint solver" [ testCase "empty" $ testSolve [] mempty mempty ([], mempty), + -- testCase "a_0 ~ b_1" $ testSolve ["a_0" ~ "b_1"] mempty (M.fromList [tv "a_0" 0]) - ([], M.fromList [("a_0", Right "b_1")]) + ([], M.fromList [("a_0", Right "b_1")]), + -- + testCase "infer unlifted" $ + testSolve + [ CtEq (ReasonBranches noLoc "t_9895" "t_9896") "t_9895" "if_t_9897", + CtEq (ReasonBranches noLoc "t_9895" "t_9896") (Scalar "t_9896") "if_t_9897" + ] + mempty + ( M.fromList + [ ("t_9895", (2, TyVarFree noLoc Lifted)), + ("t_9896", (3, TyVarFree noLoc Lifted)), + ("if_t_9897", (4, TyVarFree noLoc SizeLifted)) + ] + ) + ( [("if_t_9897", SizeLifted)], + M.fromList + [ ("t_9895", Right "if_t_9897"), + ("t_9896", Right "if_t_9897") + ] + ) ] From 82abb1562c51d9af118dde71b473393a2d511d2b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 7 Apr 2025 19:21:18 +0200 Subject: [PATCH 287/296] More realistic test. --- .../Language/Futhark/TypeChecker/TySolveTests.hs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs index 4d1dafc231..e04cdd599f 100644 --- a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs +++ b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs @@ -55,19 +55,22 @@ tests = testCase "infer unlifted" $ testSolve [ CtEq (ReasonBranches noLoc "t_9895" "t_9896") "t_9895" "if_t_9897", - CtEq (ReasonBranches noLoc "t_9895" "t_9896") (Scalar "t_9896") "if_t_9897" + CtEq (ReasonBranches noLoc "t_9895" "t_9896") (Scalar "t_9896") "if_t_9897", + "if_t_9897" ~ "res_42" ] mempty ( M.fromList [ ("t_9895", (2, TyVarFree noLoc Lifted)), ("t_9896", (3, TyVarFree noLoc Lifted)), - ("if_t_9897", (4, TyVarFree noLoc SizeLifted)) + ("if_t_9897", (4, TyVarFree noLoc SizeLifted)), + ("res_42", (1, TyVarFree noLoc Lifted)) ] ) - ( [("if_t_9897", SizeLifted)], + ( [("res_42", SizeLifted)], M.fromList - [ ("t_9895", Right "if_t_9897"), - ("t_9896", Right "if_t_9897") + [ ("t_9895", Right "res_42"), + ("t_9896", Right "res_42"), + ("if_t_9897", Right "res_42") ] ) ] From cb3642ac89c789923d4b79c0c2b976c94c6d9396 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 8 Apr 2025 08:51:58 +0200 Subject: [PATCH 288/296] Nicer printing. --- futhark.cabal | 1 + src/Language/Futhark/TypeChecker/TySolve.hs | 55 +++++++++++++------ .../Futhark/TypeChecker/TySolveTests.hs | 23 ++++---- 3 files changed, 51 insertions(+), 28 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index 47b096c321..46097523f8 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -561,6 +561,7 @@ test-suite unit , free , futhark , megaparsec + , srcloc >=0.4 , tasty , tasty-hunit , tasty-quickcheck diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs index 41077d5137..7664ce3aed 100644 --- a/src/Language/Futhark/TypeChecker/TySolve.hs +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -655,6 +655,41 @@ solveTyVar (tv, (_, TyVarPrim loc pts)) = do "which is not possible." _ -> pure () +-- Print in a way helpful for writing a test case for TySolveTests. +logSolution :: + [CtTy ()] -> + TyParams -> + TyVars () -> + Either TypeError ([UnconTyVar], Solution) -> + String +logSolution constraints typarams tyvars s = + unlines $ + ["# TySolve.solve", "## constraints"] + <> map ppConstraint constraints + <> [ "## typarams", + if typarams == mempty then "mempty" else show $ map ppTyParam (M.toList typarams) + ] + <> [ "## tyvars", + show $ map (bimap prettyNameString (second onTyVar)) $ M.toList tyvars, + either + (("## error\n" <>) . docString . prettyTypeError) + ( ("## solution\n" <>) + . show + . bimap + (map (first prettyNameString)) + (map (bimap prettyNameString $ bimap prettyString prettyString) . M.toList) + ) + s + ] + where + ppConstraint (CtEq _ t1 t2) = + unwords [show (prettyString t1), "~", show (prettyString t2)] + ppTyParam (p, (lvl, info, _)) = show (prettyNameString p, (lvl, info, NoLoc)) + onTyVar (TyVarFree _ l) = TyVarFree NoLoc l + onTyVar (TyVarPrim _ pts) = TyVarPrim NoLoc pts + onTyVar (TyVarRecord _ ts) = TyVarRecord NoLoc ts + onTyVar (TyVarSum _ ts) = TyVarSum NoLoc ts + -- | Solve type constraints, producing either an error or a solution, -- alongside a list of unconstrained type variables. solve :: @@ -663,7 +698,7 @@ solve :: TyVars () -> Either TypeError ([UnconTyVar], Solution) solve constraints typarams tyvars = - logProblem + maybeLog . second solution . runExcept . flip execStateT (initialState typarams tyvars) @@ -672,22 +707,8 @@ solve constraints typarams tyvars = mapM_ solveCt constraints mapM_ solveTyVar (M.toList tyvars) where - logProblem + maybeLog | isEnvVarAtLeast "FUTHARK_LOG_TYSOLVE" 0 = \s -> - let msg = - unlines - [ "# TySolve.solve", - "## constraints", - show constraints, - "## typarams", - show typarams, - "## tyvars", - show tyvars, - either - (("## error\n" <>) . docString . prettyTypeError) - (("## solution\n" <>) . show) - s - ] - in trace msg s + trace (logSolution constraints typarams tyvars s) s | otherwise = id {-# NOINLINE solve #-} diff --git a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs index e04cdd599f..944a2dd233 100644 --- a/unittests/Language/Futhark/TypeChecker/TySolveTests.hs +++ b/unittests/Language/Futhark/TypeChecker/TySolveTests.hs @@ -1,5 +1,6 @@ module Language.Futhark.TypeChecker.TySolveTests (tests) where +import Data.Loc (Loc (NoLoc)) import Data.Map qualified as M import Futhark.Util.Pretty (docString) import Language.Futhark.Syntax @@ -54,23 +55,23 @@ tests = -- testCase "infer unlifted" $ testSolve - [ CtEq (ReasonBranches noLoc "t_9895" "t_9896") "t_9895" "if_t_9897", - CtEq (ReasonBranches noLoc "t_9895" "t_9896") (Scalar "t_9896") "if_t_9897", - "if_t_9897" ~ "res_42" + [ "t\8320_9896" ~ "if_t\8322_9898", + "t\8321_9897" ~ "if_t\8322_9898", + "t\8323_9899" ~ "if_t\8322_9898" ] mempty ( M.fromList - [ ("t_9895", (2, TyVarFree noLoc Lifted)), - ("t_9896", (3, TyVarFree noLoc Lifted)), - ("if_t_9897", (4, TyVarFree noLoc SizeLifted)), - ("res_42", (1, TyVarFree noLoc Lifted)) + [ ("t\8320_9896", (2, TyVarFree NoLoc Lifted)), + ("t\8321_9897", (3, TyVarFree NoLoc Lifted)), + ("if_t\8322_9898", (4, TyVarFree NoLoc SizeLifted)), + ("t\8323_9899", (5, TyVarFree NoLoc Lifted)) ] ) - ( [("res_42", SizeLifted)], + ( [("if_t\8322_9898", SizeLifted)], M.fromList - [ ("t_9895", Right "res_42"), - ("t_9896", Right "res_42"), - ("if_t_9897", Right "res_42") + [ ("t\8320_9896", Right "if_t\8322_9898"), + ("t\8321_9897", Right "if_t\8322_9898"), + ("t\8323_9899", Right "if_t\8322_9898") ] ) ] From 8dcde986dca0e58ad2a1fbe6c8dbb1eb2566b121 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 May 2025 13:47:39 +0200 Subject: [PATCH 289/296] Add compiler benchmark suite. --- .../Futhark/TypeChecker/TySolveBenchmarks.hs | 43 +++++++++++++++++++ benchmarks/README.md | 4 ++ benchmarks/futhark_benchmarks.hs | 10 +++++ futhark.cabal | 32 +++++++++++--- 4 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 benchmarks/Language/Futhark/TypeChecker/TySolveBenchmarks.hs create mode 100644 benchmarks/README.md create mode 100644 benchmarks/futhark_benchmarks.hs diff --git a/benchmarks/Language/Futhark/TypeChecker/TySolveBenchmarks.hs b/benchmarks/Language/Futhark/TypeChecker/TySolveBenchmarks.hs new file mode 100644 index 0000000000..c8dbdcc4d6 --- /dev/null +++ b/benchmarks/Language/Futhark/TypeChecker/TySolveBenchmarks.hs @@ -0,0 +1,43 @@ +module Language.Futhark.TypeChecker.TySolveBenchmarks (benchmarks) where + +import Criterion (Benchmark, bench, bgroup, whnf) +import Data.Map qualified as M +import Language.Futhark.Syntax +import Language.Futhark.SyntaxTests () +import Language.Futhark.TypeChecker.Constraints + ( CtTy (..), + Level, + Reason (..), + TyParams, + TyVarInfo (..), + TyVars, + ) +import Language.Futhark.TypeChecker.Monad (TypeError (..)) +import Language.Futhark.TypeChecker.TySolve (Solution, UnconTyVar, solve) + +(~) :: TypeBase () NoUniqueness -> TypeBase () NoUniqueness -> CtTy () +t1 ~ t2 = CtEq (Reason mempty) t1 t2 + +tv :: VName -> Level -> (VName, (Level, TyVarInfo ())) +tv v lvl = (v, (lvl, TyVarFree mempty Unlifted)) + +solve' :: + ( [CtTy ()], + TyParams, + TyVars () + ) -> + Either TypeError ([UnconTyVar], Solution) +solve' (constraints, typarams, tyvars) = solve constraints typarams tyvars + +benchmarks :: Benchmark +benchmarks = + bgroup + "TySolve" + [ bench "trivial" $ + whnf + solve' + ( ["a_0" ~ "b_1"], + mempty, + M.fromList [tv "a_0" 0] + ) + ] diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000000..6ec2cc1328 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,4 @@ +# Compiler benchmarks + +This directory contains benchmarks for the Futhark compiler itself. See +[../futhark-benchmarks][../futhark-benchmarks] for Futhark benchmark programs. diff --git a/benchmarks/futhark_benchmarks.hs b/benchmarks/futhark_benchmarks.hs new file mode 100644 index 0000000000..4ed7e8a053 --- /dev/null +++ b/benchmarks/futhark_benchmarks.hs @@ -0,0 +1,10 @@ +module Main (main) where + +import Criterion.Main +import Language.Futhark.TypeChecker.TySolveBenchmarks qualified + +main :: IO () +main = + defaultMain + [ Language.Futhark.TypeChecker.TySolveBenchmarks.benchmarks + ] diff --git a/futhark.cabal b/futhark.cabal index e48f20a141..79e5987ace 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -519,12 +519,10 @@ executable futhark ghc-options: -threaded -rtsopts "-with-rtsopts=-maxN16 -qg1 -A16M" build-depends: base, futhark -test-suite unit +library futhark-testing import: common - type: exitcode-stdio-1.0 - main-is: futhark_tests.hs hs-source-dirs: unittests - other-modules: + exposed-modules: Futhark.AD.DerivativesTests Futhark.Analysis.AlgSimplifyTests Futhark.Analysis.PrimExp.TableTests @@ -558,7 +556,7 @@ test-suite unit Language.Futhark.TypeCheckerTests build-depends: QuickCheck >=2.8 - , mtl >=2.2.1 + , mtl >=2.2.1 , base , containers , free @@ -570,3 +568,27 @@ test-suite unit , tasty-quickcheck , text , vector >=0.12 + +test-suite unit + import: common + type: exitcode-stdio-1.0 + main-is: unittests/futhark_tests.hs + build-depends: + base + , futhark-testing + , tasty + +benchmark benchmarks + import: common + type: exitcode-stdio-1.0 + main-is: futhark_benchmarks.hs + hs-source-dirs: benchmarks + other-modules: + Language.Futhark.TypeChecker.TySolveBenchmarks + + build-depends: + base + , containers + , criterion + , futhark + , futhark-testing From b3b8ee008a339f97661ca0189b412e2cb884b973 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 15 May 2025 14:17:53 +0200 Subject: [PATCH 290/296] Adhoc fixes. --- nix/glpk-hs.nix | 6 +++--- src/Language/Futhark/Interpreter/Values.hs | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/nix/glpk-hs.nix b/nix/glpk-hs.nix index 6f5a2b0081..189135ed22 100644 --- a/nix/glpk-hs.nix +++ b/nix/glpk-hs.nix @@ -5,9 +5,9 @@ mkDerivation { pname = "glpk-hs"; version = "0.8"; src = fetchgit { - url = "https://github.com/ludat/glpk-hs.git"; - sha256 = "0nly5nifdb93f739vr3jzgi16fccqw5l0aabf5lglsdkdad713q1"; - rev = "efcb8354daa1205de2b862898353da2e4beb76b2"; + url = "https://github.com/jyp/glpk-hs.git"; + sha256 = "sha256-AY9wmmqzafpocUspQAvHjDkT4vty5J3GcSOt5qItnlo="; + rev = "1f276aa19861203ea8367dc27a6ad4c8a31c9062"; fetchSubmodules = true; }; isLibrary = true; diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index 7d03dc8da0..0c2644c0bf 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -35,7 +35,6 @@ where import Control.Monad.Identity import Data.Array -import Data.Bifunctor (Bifunctor (second)) import Data.List (genericLength, genericReplicate) import Data.Map qualified as M import Data.Maybe From 427a22adf9e5a474df0524f2d502d933c368393c Mon Sep 17 00:00:00 2001 From: Jacob Siegumfeldt <114237410+jacobgummer@users.noreply.github.com> Date: Tue, 8 Jul 2025 15:53:25 +0200 Subject: [PATCH 291/296] An attempt at optimizing the `TySolve` module (#2298) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * A few tests * More unit tests * Remove dead code * New function for tests that should fail * Better version of 'testSolveFail' * Extra import * Function for making type params * New tests * Space between test cases * Rename test case * Extra occurs check test * Few more tests * Explicit import * Rename 'sub' function to 'normalize' * Clear everything * Add Union-Find data structure * Refactor * Add new version of TySolve * Return to original implementation (for now) * Import Control.Monad.ST * Fix comments * Add todo * Add import * Update todo * Refactor most of UnionFind module * Refactor SolveM and implement initialState * Explicit imports * Clarify comment * Add skeleton for new implementation * Always assign second argument node's key to result of union * Make some things more readable * Debug prints * Clarify comment * More information in debug prints * Add regex support * Use regexes on expected error messages * Minor cleanup * Add function for making type variables with constraints * Explicitly import '=~' for regexes * Add possible new representation of type variables * Add 'level' field to ReprInfo * Update union function to track minimum level in ReprInfo - It might be better to handle the level logic in TySolve but we'll have it like this for now * Renaming * Add 'lvl' as argument to 'makeTyVarNode' * Clarify comment * Improve debug output in subTyVar by listing type variables in trace message * Remove 'level' (for now) * Remove debug prints * Refactor 'union' to correctly handle which key new root gets * Implement 'unionTyVars' (with helper functions) * Use '$' instead * Rename 'assignType' to 'assignNewSol' * occursCheck working? * Implement 'solveEq' * subTyVar implemented? * Remove redundant functions * Rename variable * Remove unused Unifier type definition * Implement 'solve' * Add scope and liftedness checks * Fix some stuff * Use new implementation of TySolve for tests * Outcomment most tests for now * Minor fix * Add 'isRepr' function * Comment out unused function * Readd all tests * Minor fixes in 'solve' * Fix 'solution' * Fix 'scopeCheck' * Remove 'traceM' import * Mini refactor * Utilize that 'solution' still has access to the state * Style fix * Avoid converting to and from list * Use state instead of function parameter * Fix when type variable is linked to other type variable * Add test case * Remove old pattern * Fix indentation * Rename `tv` to `tvFree` * Remove unused function * Add functions for creating different types of flexible type variables * Rename variable * Traverse map instead of converting to and from list * Use state instead of formal parameter * Missing space * Rename `initialState` to `initializeState` * Add todo * Use our implementation of 'substTyVars' * Optimize 'solution' function * Remove redundant function * Change variable name * Obey line length limit here * Rename this function to clarify that it returns a tyvar's solution * Split 'mkSubst's type signature into multiple lines * Refactor duplicated code into function * Remove a line * Replace 'enrichType' with 'substTyVars' * Avoid redundant lookup * Small refactor * Add clarifying comment * Move this down * Clarify this comment * Remove this comment * Use function composition here * Use function composition here * Use these "specialized" functions instead * Rename variables for improved readability * Add todo * Mini refactor * Rename these for clarity * Fix todo * Add short documentation for `getKey` * Remove redundant function * Use '@' instead * Fix bug in `solveEq` * Add tests for record with polymorphic fields and opaque type handling * Add tests for propagation of liftedness * Add extra test that should result in a scope violation * Add logging possibility to new implementation * Consistency fix * Add '|' to make documentation visble when hovering over function names * Temporary fix of occurs check bug * New way of doing occurs check - Not sure if this is a complete fix but all tests still pass * Add todo * Add todo * Try to provide somewhat more useful error information * Add new tests for occurs check * Format for better readability * Small optimization of `occursCheck` * Remove outcommented code * Overwrite old type constraint solver with new implementation - Refactored the `UnionFind` module to include a new function `getLvl` for retrieving the level of a type variable node. - Updated the `union` function in `UnionFind` to properly handle the level of type variables during union operations. * Avoid normalizing in here, too * Rewrite `TyVarFree` case in `solveTyVar` * Remove comment * Remove todo * Remove dead code * Change the expected result - This is equivalent to the former substitution so it is still a correct solution * Change test case names * Fix bug when doing scope check * Minor change in type signature * Temporary fix of scope check * scopeCheck in solveTyVar to pass level1.fut test * Add benchmarking script * Remove dead code * Remove more dead code * Remove redundant function * Remove redundant import * Remove strict unpacking for solution and key fields * Refactor path compression logic in `find` * Update path compression logic (again) * Rename `solution` to `getSolution` * Refactor path compression logic in `find` * Substitute with the root type variable when original type variable isn't solved * Benchmarking * Refactor `SolveM` - Changed the stack order to improve performance - Changed from using `StateT` to `ReaderT` to make it more clear that we don't alter the map from type variables to their nodes * Add benchmarking script for Futhark performance comparison * Unnecessary * Fix comments * Remove information about `level`s inside representatives * Remove this * Remove to-do's * Benchs now include cases with less than 30 cons * Remove unused import * Add hash tables * Add extra check in `find` * Disable union-by-weight * Remove hash table imports * Use `$` instead * Flatten the node structure - `Repr` is no contained in a reference * Remove weight from information about representative * A little more clean * Remove strictness annotations * Remove `hashtables` dependency * Refactor `bindTyVar` to accept a node as a parameter * Refactor `solveEq` - Possibly return a node from `flexible` to avoid multiple lookups - Rename `sub` to `normalize` - Make `solveCt'` more readable in general * Refactor `lookupTyVar` to accept a `TyVarNode` * Refactor `unionTyVars` to accept `TyVarNode` parameters * Remove redundant check in `TyVarPrim` case in `solveTyVar` * Refactor `getSolution` - Resolve every type variable's type first, then find unconstrained and solutions * Use non-strict version of `foldrWithKey` - Performance is more or less the same so no need to introduce strictness * Consistency * 'union' is probably a better word here * Refactor `flexible` function to simplify return type and improve readability in `solveEq` * Add comment * Refactor `scopeCheck` * Simplify these cases to only update liftedness when necessary * Flip branches in if-statement * Split type annotation into multiple lines * Add clarifying comment * Move this comment * Add clarifying error message (just in case) * Change variable name from `n` to `node` * Change name of this variable * Split this type annotation into multiple lines * Split this case into multiple lines * Split this type annotation into multiple lines * Switch order of cases in `solveTyVar` to reflect order in `TyVarInfo` definition * Remove this colon * Remove profiling option * Remove redundant `do` * Rename `LinkInfo` to `NodeInfo` * Change variable names for clarity * Refactor `NodeInfo` structure - `ReprInfo` helps make operations more type safe - Strict fields help improve performance * Turn `get` functions into one-liners * Clarify `find` documentation * Clearer and more consistent naming of variables * A tiny bit more readable * Add `unionNewSol` function to explicitly pass new solution as argument * Refactor `unionTyVars` - Checking if unification is possible before joining equivalence classes - Using `unionNewSol` to explicitly choose new solution of resulting equivalence class * Put this on the same line * Clarify comment * Style fixes. * Fix futhark-benchmarks submodule * Remove generated type solver benchmarks. They are much too large and many of them fail. The idea is good, so maybe we can bring them back in a more controlled way someday. * Style fixes. * Fewer of these. * Restore a real benchmark. * Better grouping. * Reformat. --------- Co-authored-by: Laust K. Dengsøe Co-authored-by: Laust Kjæp Dengsøe <42211117+LaustDengsoe@users.noreply.github.com> Co-authored-by: Troels Henriksen --- benchmarks/README.md | 4 - benchmarks/futhark_benchmarks.hs | 10 - futhark-benchmarks | 2 +- futhark.cabal | 6 + src-testing/Generated/AllFutBenchmarks.hs | 24 + .../Accelerate/Nbody/Nbodybh.hs | 1534 +++++++++++++++++ .../Futhark/TypeChecker/TySolveBenchmarks.hs | 74 +- .../Futhark/TypeChecker/TySolveTests.hs | 368 +++- src-testing/futhark_benchmarks.hs | 4 +- src/Language/Futhark/TypeChecker/TySolve.hs | 977 ++++++----- .../Futhark/TypeChecker/TySolveOld.hs | 703 ++++++++ src/Language/Futhark/TypeChecker/UnionFind.hs | 137 ++ 12 files changed, 3371 insertions(+), 472 deletions(-) delete mode 100644 benchmarks/README.md delete mode 100644 benchmarks/futhark_benchmarks.hs create mode 100644 src-testing/Generated/AllFutBenchmarks.hs create mode 100644 src-testing/Generated/AllFutBenchmarks/Accelerate/Nbody/Nbodybh.hs create mode 100644 src/Language/Futhark/TypeChecker/TySolveOld.hs create mode 100644 src/Language/Futhark/TypeChecker/UnionFind.hs diff --git a/benchmarks/README.md b/benchmarks/README.md deleted file mode 100644 index 6ec2cc1328..0000000000 --- a/benchmarks/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# Compiler benchmarks - -This directory contains benchmarks for the Futhark compiler itself. See -[../futhark-benchmarks][../futhark-benchmarks] for Futhark benchmark programs. diff --git a/benchmarks/futhark_benchmarks.hs b/benchmarks/futhark_benchmarks.hs deleted file mode 100644 index 4ed7e8a053..0000000000 --- a/benchmarks/futhark_benchmarks.hs +++ /dev/null @@ -1,10 +0,0 @@ -module Main (main) where - -import Criterion.Main -import Language.Futhark.TypeChecker.TySolveBenchmarks qualified - -main :: IO () -main = - defaultMain - [ Language.Futhark.TypeChecker.TySolveBenchmarks.benchmarks - ] diff --git a/futhark-benchmarks b/futhark-benchmarks index 13d3cb5cb2..0d427b9483 160000 --- a/futhark-benchmarks +++ b/futhark-benchmarks @@ -1 +1 @@ -Subproject commit 13d3cb5cb2c887adca2bf4fbd02f9e866436cbfe +Subproject commit 0d427b94838beea4d6512f7639860e5b967ce7bc diff --git a/futhark.cabal b/futhark.cabal index 2004a1c905..d72eb9a204 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -429,6 +429,8 @@ library Language.Futhark.TypeChecker.Consumption Language.Futhark.TypeChecker.Constraints Language.Futhark.TypeChecker.TySolve + Language.Futhark.TypeChecker.TySolveOld + Language.Futhark.TypeChecker.UnionFind Language.Futhark.TypeChecker.Error Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match @@ -523,6 +525,8 @@ library futhark-testing import: common hs-source-dirs: src-testing exposed-modules: + Generated.AllFutBenchmarks + Generated.AllFutBenchmarks.Accelerate.Nbody.Nbodybh Futhark.AD.DerivativesTests Futhark.Analysis.AlgSimplifyTests Futhark.Analysis.PrimExp.TableTests @@ -572,6 +576,8 @@ library futhark-testing , tasty-quickcheck , text , vector >=0.12 + , srcloc + , regex-tdfa ^>= 1.3.2 test-suite unit import: common diff --git a/src-testing/Generated/AllFutBenchmarks.hs b/src-testing/Generated/AllFutBenchmarks.hs new file mode 100644 index 0000000000..b3f0c736e0 --- /dev/null +++ b/src-testing/Generated/AllFutBenchmarks.hs @@ -0,0 +1,24 @@ +module Generated.AllFutBenchmarks + ( allFutBenchmarkCases, + BenchmarkCaseData, + ) +where + +import Generated.AllFutBenchmarks.Accelerate.Nbody.Nbodybh qualified as AllFutBenchmarksAccelerateNbodyNbodybh +import Language.Futhark.TypeChecker.Constraints (CtTy, TyParams, TyVars) + +type BenchmarkCaseData = ([CtTy ()], TyParams, TyVars ()) + +allFutBenchmarkCases :: [(String, BenchmarkCaseData)] +allFutBenchmarkCases = + [ ("accelerate/nbody/nbody-bh.fut (Block 1/10) (Cons: 121)", head AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList), + ("accelerate/nbody/nbody-bh.fut (Block 2/10) (Cons: 146)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 1), + ("accelerate/nbody/nbody-bh.fut (Block 3/10) (Cons: 133)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 2), + ("accelerate/nbody/nbody-bh.fut (Block 4/10) (Cons: 45)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 3), + ("accelerate/nbody/nbody-bh.fut (Block 5/10) (Cons: 210)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 4), + ("accelerate/nbody/nbody-bh.fut (Block 6/10) (Cons: 401)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 5), + ("accelerate/nbody/nbody-bh.fut (Block 7/10) (Cons: 39)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 6), + ("accelerate/nbody/nbody-bh.fut (Block 8/10) (Cons: 173)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 7), + ("accelerate/nbody/nbody-bh.fut (Block 9/10) (Cons: 164)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 8), + ("accelerate/nbody/nbody-bh.fut (Block 10/10) (Cons: 40)", AllFutBenchmarksAccelerateNbodyNbodybh.benchmarkDataList !! 9) + ] diff --git a/src-testing/Generated/AllFutBenchmarks/Accelerate/Nbody/Nbodybh.hs b/src-testing/Generated/AllFutBenchmarks/Accelerate/Nbody/Nbodybh.hs new file mode 100644 index 0000000000..7b23719cba --- /dev/null +++ b/src-testing/Generated/AllFutBenchmarks/Accelerate/Nbody/Nbodybh.hs @@ -0,0 +1,1534 @@ +module Generated.AllFutBenchmarks.Accelerate.Nbody.Nbodybh (benchmarkDataList) where + +import Data.Map qualified as M +import Futhark.Util.Loc (Loc (NoLoc)) +import Language.Futhark.Syntax +import Language.Futhark.SyntaxTests () +import Language.Futhark.TypeChecker.Constraints + ( CtTy (..), + Reason (..), + TyParams, + TyVarInfo (..), + TyVars, + ) + +(~) :: TypeBase () NoUniqueness -> TypeBase () NoUniqueness -> CtTy () +t1 ~ t2 = CtEq (Reason mempty) t1 t2 + +type BenchmarkCaseData = ([CtTy ()], TyParams, TyVars ()) + +benchmarkDataList :: [BenchmarkCaseData] +benchmarkDataList = + [ ( [ "t_8322_8328_8326" ~ "[]t_8322_8328_8326_8322_8329_8327", + "t_8322_8328_8328" ~ "[]t_8322_8328_8328_8322_8329_8328", + "t_8321_8323_8326" ~ "[]t_8321_8323_8326_8322_8329_8329", + "t_8327_8327" ~ "[]t_8327_8327_8323_8320_8320", + "t_8321_8322_8321" ~ "[]t_8321_8322_8321_8323_8320_8321", + "b_8327_8329" ~ "[]b_8327_8329_8323_8320_8322", + "a_8327_8328" ~ "[]a_8327_8328_8323_8320_8323", + "b_8326_8328" ~ "[]b_8326_8328_8323_8320_8324", + "a_8326_8327" ~ "[]a_8326_8327_8323_8320_8325", + "i32" ~ "t_8323", + "num_8324" ~ "t_8323", + "t_8323" ~ "i32", + "t_8320" ~ "t_1", + "i32" ~ "t_8322", + "num_8321_8323" ~ "t_8322", + "i32" ~ "i32", + "t_8320" ~ "t_1", + "t_8322" ~ "t_8321", + "i32" ~ "t_8321", + "t_8322_8326" ~ "arg_8323_8325 -> res_8323_8326", + "t_8322_8327" ~ "arg_8323_8325", + "res_8323_8326" ~ "arg_8323_8329 -> res_8324_8320", + "t_8323_8321" ~ "arg_8323_8329", + "t_8322_8326" ~ "arg_8324_8323 -> res_8324_8324", + "t_8322_8328" ~ "arg_8324_8323", + "res_8324_8324" ~ "arg_8324_8327 -> res_8324_8328", + "t_8323_8322" ~ "arg_8324_8327", + "t_8322_8326" ~ "arg_8325_8321 -> res_8325_8322", + "t_8322_8329" ~ "arg_8325_8321", + "res_8325_8322" ~ "arg_8325_8325 -> res_8325_8326", + "t_8323_8323" ~ "arg_8325_8325", + "t_8322_8326" ~ "arg_8325_8329 -> res_8326_8320", + "t_8323_8320" ~ "arg_8325_8329", + "res_8326_8320" ~ "arg_8326_8323 -> res_8326_8324", + "t_8323_8324" ~ "arg_8326_8323", + "{x: t_8320} -> t_8321" ~ "a_8326_8329 -> x_8327_8320", + "[]t_1" ~ "a_8326_8327", + "{as: []a_8326_8329} -> *[]x_8327_8320" ~ "a_8326_8327 -> b_8326_8328", + "t_8327_8327" ~ "b_8326_8328", + "t_8328_8322" ~ "t_8328_8323", + "num_8328_8324" ~ "t_8328_8323", + "bool" ~ "bool", + "t_8328_8322" ~ "t_8329_8321", + "num_8329_8322" ~ "t_8329_8321", + "bool" ~ "bool", + "t_8328_8322" ~ "t_8329_8329", + "num_8321_8320_8320" ~ "t_8329_8329", + "bool" ~ "bool", + "t_8328_8322" ~ "t_8321_8320_8327", + "num_8321_8320_8328" ~ "t_8321_8320_8327", + "bool" ~ "bool", + "{x: t_8328_8322} -> (i64, i64, i64, i64)" ~ "a_8328_8320 -> x_8328_8321", + "t_8327_8327" ~ "a_8327_8328", + "{as: []a_8328_8320} -> *[]x_8328_8321" ~ "a_8327_8328 -> b_8327_8329", + "t_8321_8322_8321" ~ "b_8327_8329", + "t_8321_8322_8323 -> t_8321_8322_8323 -> t_8321_8322_8323" ~ "t_8322_8326", + "(t_8322_8327, t_8322_8328, t_8322_8329, t_8323_8320) -> (t_8323_8321, t_8323_8322, t_8323_8323, t_8323_8324) -> (res_8324_8320, res_8324_8328, res_8325_8326, res_8326_8324)" ~ "a_8321_8322_8322 -> a_8321_8322_8322 -> a_8321_8322_8322", + "(num_8321_8322_8326, num_8321_8322_8327, num_8321_8322_8328, num_8321_8322_8329)" ~ "a_8321_8322_8322", + "t_8321_8322_8321" ~ "[]a_8321_8322_8322", + "t_8321_8323_8326" ~ "[]a_8321_8322_8322", + "t_8321_8323_8326" ~ "[]t_8321_8323_8327", + "(t_8321_8324_8320, t_8321_8324_8321, t_8321_8324_8322, t_8321_8324_8323)" ~ "t_8321_8323_8327", + "t_8321_8325_8327" ~ "num_8321_8325_8326", + "t_8321_8324_8324" ~ "t_8321_8325_8329", + "num_8321_8326_8320" ~ "t_8321_8325_8329", + "bool" ~ "bool", + "t_8321_8324_8325" ~ "t_8321_8325_8328", + "i64" ~ "t_8321_8325_8328", + "num_8321_8325_8326" ~ "t_8321_8325_8325", + "t_8321_8325_8328" ~ "t_8321_8325_8325", + "t_8321_8324_8324" ~ "t_8321_8327_8326", + "num_8321_8327_8327" ~ "t_8321_8327_8326", + "bool" ~ "bool", + "t_8321_8324_8320" ~ "t_8321_8327_8325", + "i64" ~ "t_8321_8327_8325", + "t_8321_8325_8325" ~ "t_8321_8325_8324", + "t_8321_8327_8325" ~ "t_8321_8325_8324", + "t_8321_8324_8324" ~ "t_8321_8329_8323", + "num_8321_8329_8324" ~ "t_8321_8329_8323", + "bool" ~ "bool", + "t_8321_8324_8326" ~ "t_8321_8329_8322", + "i64" ~ "t_8321_8329_8322", + "t_8321_8325_8324" ~ "t_8321_8325_8323", + "t_8321_8329_8322" ~ "t_8321_8325_8323", + "t_8321_8324_8324" ~ "t_8322_8321_8320", + "num_8322_8321_8321" ~ "t_8322_8321_8320", + "bool" ~ "bool", + "t_8321_8324_8321" ~ "t_8322_8320_8329", + "i64" ~ "t_8322_8320_8329", + "t_8321_8325_8323" ~ "t_8321_8325_8322", + "t_8322_8320_8329" ~ "t_8321_8325_8322", + "t_8321_8324_8324" ~ "t_8322_8322_8327", + "num_8322_8322_8328" ~ "t_8322_8322_8327", + "bool" ~ "bool", + "t_8321_8324_8327" ~ "t_8322_8322_8326", + "i64" ~ "t_8322_8322_8326", + "t_8321_8325_8322" ~ "t_8321_8325_8321", + "t_8322_8322_8326" ~ "t_8321_8325_8321", + "t_8321_8324_8324" ~ "t_8322_8324_8324", + "num_8322_8324_8325" ~ "t_8322_8324_8324", + "bool" ~ "bool", + "t_8321_8324_8322" ~ "t_8322_8324_8323", + "i64" ~ "t_8322_8324_8323", + "t_8321_8325_8321" ~ "t_8321_8325_8320", + "t_8322_8324_8323" ~ "t_8321_8325_8320", + "t_8321_8324_8324" ~ "t_8322_8326_8321", + "num_8322_8326_8322" ~ "t_8322_8326_8321", + "bool" ~ "bool", + "t_8321_8324_8328" ~ "t_8322_8326_8320", + "i64" ~ "t_8322_8326_8320", + "t_8321_8325_8320" ~ "t_8321_8324_8329", + "t_8322_8326_8320" ~ "t_8321_8324_8329", + "{bin: t_8321_8324_8324} -> (t_8321_8324_8325, t_8321_8324_8326, t_8321_8324_8327, t_8321_8324_8328) -> t_8321_8324_8329" ~ "a_8322_8327_8327 -> b_8322_8327_8328 -> x_8322_8327_8329", + "t_8327_8327" ~ "[]a_8322_8327_8327", + "t_8321_8323_8326" ~ "[]b_8322_8327_8328", + "t_8322_8328_8326" ~ "[]x_8322_8327_8329", + "[]t_1" ~ "t_8322_8328_8328", + "t_8322_8328_8328" ~ "[]t_8322_8328_8327", + "t_8322_8328_8326" ~ "[]i64", + "[]t_1" ~ "[]t_8322_8328_8327", + "[]t_1" ~ "[]t_8322_8328_8327" + ], + M.fromList [("t_1", (0, Unlifted, NoLoc))], + M.fromList [("t_8320", (5, TyVarFree NoLoc Lifted)), ("t_8321", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8323", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8326", (5, TyVarFree NoLoc Lifted)), ("t_8322_8327", (6, TyVarFree NoLoc Lifted)), ("t_8322_8328", (6, TyVarFree NoLoc Lifted)), ("t_8322_8329", (6, TyVarFree NoLoc Lifted)), ("t_8323_8320", (6, TyVarFree NoLoc Lifted)), ("t_8323_8321", (7, TyVarFree NoLoc Lifted)), ("t_8323_8322", (7, TyVarFree NoLoc Lifted)), ("t_8323_8323", (7, TyVarFree NoLoc Lifted)), ("t_8323_8324", (7, TyVarFree NoLoc Lifted)), ("arg_8323_8325", (8, TyVarFree NoLoc Lifted)), ("res_8323_8326", (8, TyVarFree NoLoc Lifted)), ("arg_8323_8329", (8, TyVarFree NoLoc Lifted)), ("res_8324_8320", (8, TyVarFree NoLoc Lifted)), ("arg_8324_8323", (8, TyVarFree NoLoc Lifted)), ("res_8324_8324", (8, TyVarFree NoLoc Lifted)), ("arg_8324_8327", (8, TyVarFree NoLoc Lifted)), ("res_8324_8328", (8, TyVarFree NoLoc Lifted)), ("arg_8325_8321", (8, TyVarFree NoLoc Lifted)), ("res_8325_8322", (8, TyVarFree NoLoc Lifted)), ("arg_8325_8325", (8, TyVarFree NoLoc Lifted)), ("res_8325_8326", (8, TyVarFree NoLoc Lifted)), ("arg_8325_8329", (8, TyVarFree NoLoc Lifted)), ("res_8326_8320", (8, TyVarFree NoLoc Lifted)), ("arg_8326_8323", (8, TyVarFree NoLoc Lifted)), ("res_8326_8324", (8, TyVarFree NoLoc Lifted)), ("a_8326_8327", (4, TyVarFree NoLoc Lifted)), ("b_8326_8328", (4, TyVarFree NoLoc Lifted)), ("a_8326_8329", (4, TyVarFree NoLoc Unlifted)), ("x_8327_8320", (4, TyVarFree NoLoc Unlifted)), ("t_8327_8327", (5, TyVarFree NoLoc Lifted)), ("a_8327_8328", (6, TyVarFree NoLoc Lifted)), ("b_8327_8329", (6, TyVarFree NoLoc Lifted)), ("a_8328_8320", (6, TyVarFree NoLoc Unlifted)), ("x_8328_8321", (6, TyVarFree NoLoc Unlifted)), ("t_8328_8322", (7, TyVarFree NoLoc Lifted)), ("t_8328_8323", (8, TyVarFree NoLoc Unlifted)), ("num_8328_8324", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8321", (8, TyVarFree NoLoc Unlifted)), ("num_8329_8322", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8329", (8, TyVarFree NoLoc Unlifted)), ("num_8321_8320_8320", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8327", (8, TyVarFree NoLoc Unlifted)), ("num_8321_8320_8328", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8322_8321", (7, TyVarFree NoLoc Lifted)), ("a_8321_8322_8322", (8, TyVarFree NoLoc Unlifted)), ("t_8321_8322_8323", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8322_8326", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8322_8327", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8322_8328", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8322_8329", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8326", (9, TyVarFree NoLoc Lifted)), ("t_8321_8323_8327", (10, TyVarFree NoLoc Unlifted)), ("t_8321_8324_8320", (11, TyVarFree NoLoc Lifted)), ("t_8321_8324_8321", (11, TyVarFree NoLoc Lifted)), ("t_8321_8324_8322", (11, TyVarFree NoLoc Lifted)), ("t_8321_8324_8323", (11, TyVarFree NoLoc Lifted)), ("t_8321_8324_8324", (13, TyVarFree NoLoc Lifted)), ("t_8321_8324_8325", (14, TyVarFree NoLoc Lifted)), ("t_8321_8324_8326", (14, TyVarFree NoLoc Lifted)), ("t_8321_8324_8327", (14, TyVarFree NoLoc Lifted)), ("t_8321_8324_8328", (14, TyVarFree NoLoc Lifted)), ("t_8321_8324_8329", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8320", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8321", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8322", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8323", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8324", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8325", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8325_8326", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8327", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8328", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8329", (15, TyVarFree NoLoc Unlifted)), ("num_8321_8326_8320", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8325", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8326", (15, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8327_8327", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8322", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8323", (15, TyVarFree NoLoc Unlifted)), ("num_8321_8329_8324", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8329", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8321_8320", (15, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8321_8321", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8322_8326", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8322_8327", (15, TyVarFree NoLoc Unlifted)), ("num_8322_8322_8328", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8323", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8324", (15, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8324_8325", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8326_8320", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8326_8321", (15, TyVarFree NoLoc Unlifted)), ("num_8322_8326_8322", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("a_8322_8327_8327", (12, TyVarFree NoLoc Unlifted)), ("b_8322_8327_8328", (12, TyVarFree NoLoc Unlifted)), ("x_8322_8327_8329", (12, TyVarFree NoLoc Unlifted)), ("t_8322_8328_8326", (13, TyVarFree NoLoc Lifted)), ("t_8322_8328_8327", (14, TyVarFree NoLoc Unlifted)), ("t_8322_8328_8328", (14, TyVarFree NoLoc Unlifted)), ("t_8322_8328_8326_8322_8329_8327", (13, TyVarFree NoLoc Lifted)), ("t_8322_8328_8328_8322_8329_8328", (14, TyVarFree NoLoc Unlifted)), ("t_8321_8323_8326_8322_8329_8329", (9, TyVarFree NoLoc Lifted)), ("t_8327_8327_8323_8320_8320", (5, TyVarFree NoLoc Lifted)), ("t_8321_8322_8321_8323_8320_8321", (7, TyVarFree NoLoc Lifted)), ("b_8327_8329_8323_8320_8322", (6, TyVarFree NoLoc Lifted)), ("a_8327_8328_8323_8320_8323", (6, TyVarFree NoLoc Lifted)), ("b_8326_8328_8323_8320_8324", (4, TyVarFree NoLoc Lifted)), ("a_8326_8327_8323_8320_8325", (4, TyVarFree NoLoc Lifted))] + ), + ( [ "t_8323_8320_8321" ~ "[]t_8323_8320_8321_8323_8323_8328", + "t_8323_8320_8323" ~ "[]t_8323_8320_8323_8323_8323_8329", + "t_8321_8323_8328" ~ "[]t_8321_8323_8328_8323_8324_8320", + "t_8327_8329" ~ "[]t_8327_8329_8323_8324_8321", + "t_8321_8322_8323" ~ "[]t_8321_8322_8323_8323_8324_8322", + "b_8328_8321" ~ "[]b_8328_8321_8323_8324_8323", + "a_8328_8320" ~ "[]a_8328_8320_8323_8324_8324", + "b_8327_8320" ~ "[]b_8327_8320_8323_8324_8325", + "a_8326_8329" ~ "[]a_8326_8329_8323_8324_8326", + "i32" ~ "t_8323", + "num_8324" ~ "t_8323", + "t_8323" ~ "i32", + "t_8320" ~ "t_1", + "i32" ~ "t_8322", + "num_8321_8323" ~ "t_8322", + "i32" ~ "i32", + "t_8320" ~ "t_1", + "t_8322" ~ "t_8321", + "i32" ~ "t_8321", + "t_8321" ~ "i32", + "t_8322_8328" ~ "arg_8323_8327 -> res_8323_8328", + "t_8322_8329" ~ "arg_8323_8327", + "res_8323_8328" ~ "arg_8324_8321 -> res_8324_8322", + "t_8323_8323" ~ "arg_8324_8321", + "t_8322_8328" ~ "arg_8324_8325 -> res_8324_8326", + "t_8323_8320" ~ "arg_8324_8325", + "res_8324_8326" ~ "arg_8324_8329 -> res_8325_8320", + "t_8323_8324" ~ "arg_8324_8329", + "t_8322_8328" ~ "arg_8325_8323 -> res_8325_8324", + "t_8323_8321" ~ "arg_8325_8323", + "res_8325_8324" ~ "arg_8325_8327 -> res_8325_8328", + "t_8323_8325" ~ "arg_8325_8327", + "t_8322_8328" ~ "arg_8326_8321 -> res_8326_8322", + "t_8323_8322" ~ "arg_8326_8321", + "res_8326_8322" ~ "arg_8326_8325 -> res_8326_8326", + "t_8323_8326" ~ "arg_8326_8325", + "{x: t_8320} -> i16" ~ "a_8327_8321 -> x_8327_8322", + "[]t_1" ~ "a_8326_8329", + "{as: []a_8327_8321} -> *[]x_8327_8322" ~ "a_8326_8329 -> b_8327_8320", + "t_8327_8329" ~ "b_8327_8320", + "t_8328_8324" ~ "t_8328_8325", + "num_8328_8326" ~ "t_8328_8325", + "bool" ~ "bool", + "t_8328_8324" ~ "t_8329_8323", + "num_8329_8324" ~ "t_8329_8323", + "bool" ~ "bool", + "t_8328_8324" ~ "t_8321_8320_8321", + "num_8321_8320_8322" ~ "t_8321_8320_8321", + "bool" ~ "bool", + "t_8328_8324" ~ "t_8321_8320_8329", + "num_8321_8321_8320" ~ "t_8321_8320_8329", + "bool" ~ "bool", + "{x: t_8328_8324} -> (i16, i16, i16, i16)" ~ "a_8328_8322 -> x_8328_8323", + "t_8327_8329" ~ "a_8328_8320", + "{as: []a_8328_8322} -> *[]x_8328_8323" ~ "a_8328_8320 -> b_8328_8321", + "t_8321_8322_8323" ~ "b_8328_8321", + "t_8321_8322_8325 -> t_8321_8322_8325 -> t_8321_8322_8325" ~ "t_8322_8328", + "(t_8322_8329, t_8323_8320, t_8323_8321, t_8323_8322) -> (t_8323_8323, t_8323_8324, t_8323_8325, t_8323_8326) -> (res_8324_8322, res_8325_8320, res_8325_8328, res_8326_8326)" ~ "a_8321_8322_8324 -> a_8321_8322_8324 -> a_8321_8322_8324", + "(num_8321_8322_8328, num_8321_8322_8329, num_8321_8323_8320, num_8321_8323_8321)" ~ "a_8321_8322_8324", + "t_8321_8322_8323" ~ "[]a_8321_8322_8324", + "t_8321_8323_8328" ~ "[]a_8321_8322_8324", + "i64" ~ "t_8321_8323_8329", + "num_8321_8324_8320" ~ "t_8321_8323_8329", + "t_8321_8323_8328" ~ "[]t_8321_8324_8329", + "bool" ~ "bool", + "(num_8321_8324_8325, num_8321_8324_8326, num_8321_8324_8327, num_8321_8324_8328)" ~ "if_t_8321_8325_8322", + "t_8321_8324_8329" ~ "if_t_8321_8325_8322", + "(t_8321_8325_8323, t_8321_8325_8324, t_8321_8325_8325, t_8321_8325_8326)" ~ "if_t_8321_8325_8322", + "t_8321_8327_8320" ~ "num_8321_8326_8329", + "t_8321_8325_8327" ~ "t_8321_8327_8322", + "num_8321_8327_8323" ~ "t_8321_8327_8322", + "bool" ~ "bool", + "t_8321_8325_8328" ~ "t_8321_8327_8321", + "i16" ~ "t_8321_8327_8321", + "num_8321_8326_8329" ~ "t_8321_8326_8328", + "t_8321_8327_8321" ~ "t_8321_8326_8328", + "t_8321_8325_8327" ~ "t_8321_8328_8329", + "num_8321_8329_8320" ~ "t_8321_8328_8329", + "bool" ~ "bool", + "t_8321_8325_8323" ~ "t_8321_8328_8328", + "i16" ~ "t_8321_8328_8328", + "t_8321_8326_8328" ~ "t_8321_8326_8327", + "t_8321_8328_8328" ~ "t_8321_8326_8327", + "t_8321_8325_8327" ~ "t_8322_8320_8326", + "num_8322_8320_8327" ~ "t_8322_8320_8326", + "bool" ~ "bool", + "t_8321_8325_8329" ~ "t_8322_8320_8325", + "i16" ~ "t_8322_8320_8325", + "t_8321_8326_8327" ~ "t_8321_8326_8326", + "t_8322_8320_8325" ~ "t_8321_8326_8326", + "t_8321_8325_8327" ~ "t_8322_8322_8323", + "num_8322_8322_8324" ~ "t_8322_8322_8323", + "bool" ~ "bool", + "t_8321_8325_8324" ~ "t_8322_8322_8322", + "i16" ~ "t_8322_8322_8322", + "t_8321_8326_8326" ~ "t_8321_8326_8325", + "t_8322_8322_8322" ~ "t_8321_8326_8325", + "t_8321_8325_8327" ~ "t_8322_8324_8320", + "num_8322_8324_8321" ~ "t_8322_8324_8320", + "bool" ~ "bool", + "t_8321_8326_8320" ~ "t_8322_8323_8329", + "i16" ~ "t_8322_8323_8329", + "t_8321_8326_8325" ~ "t_8321_8326_8324", + "t_8322_8323_8329" ~ "t_8321_8326_8324", + "t_8321_8325_8327" ~ "t_8322_8325_8327", + "num_8322_8325_8328" ~ "t_8322_8325_8327", + "bool" ~ "bool", + "t_8321_8325_8325" ~ "t_8322_8325_8326", + "i16" ~ "t_8322_8325_8326", + "t_8321_8326_8324" ~ "t_8321_8326_8323", + "t_8322_8325_8326" ~ "t_8321_8326_8323", + "t_8321_8325_8327" ~ "t_8322_8327_8324", + "num_8322_8327_8325" ~ "t_8322_8327_8324", + "bool" ~ "bool", + "t_8321_8326_8321" ~ "t_8322_8327_8323", + "i16" ~ "t_8322_8327_8323", + "t_8321_8326_8323" ~ "t_8321_8326_8322", + "t_8322_8327_8323" ~ "t_8321_8326_8322", + "t_8321_8326_8322" ~ "i16", + "{bin: t_8321_8325_8327} -> (t_8321_8325_8328, t_8321_8325_8329, t_8321_8326_8320, t_8321_8326_8321) -> i64" ~ "a_8322_8329_8322 -> b_8322_8329_8323 -> x_8322_8329_8324", + "t_8327_8329" ~ "[]a_8322_8329_8322", + "t_8321_8323_8328" ~ "[]b_8322_8329_8323", + "t_8323_8320_8321" ~ "[]x_8322_8329_8324", + "[]t_1" ~ "t_8323_8320_8323", + "t_8323_8320_8323" ~ "[]t_8323_8320_8322", + "t_8323_8320_8321" ~ "[]i64", + "[]t_1" ~ "[]t_8323_8320_8322", + "t_8321_8325_8323" ~ "et_8323_8321_8324", + "t_8321_8325_8324" ~ "et_8323_8321_8324", + "t_8321_8325_8325" ~ "et_8323_8321_8324", + "t_8321_8325_8326" ~ "et_8323_8321_8324", + "i16 -> i64" ~ "a_8323_8321_8322 -> x_8323_8321_8323", + "[]et_8323_8321_8324" ~ "[]a_8323_8321_8322", + "num_8323_8322_8320" ~ "et_8323_8321_8329", + "t_8321_8325_8323" ~ "et_8323_8321_8329", + "t_8321_8325_8323" ~ "t_8323_8322_8321", + "t_8321_8325_8324" ~ "t_8323_8322_8321", + "t_8323_8322_8321" ~ "et_8323_8321_8329", + "t_8321_8325_8323" ~ "t_8323_8322_8327", + "t_8321_8325_8324" ~ "t_8323_8322_8327", + "t_8323_8322_8327" ~ "t_8323_8322_8326", + "t_8321_8325_8325" ~ "t_8323_8322_8326", + "t_8323_8322_8326" ~ "et_8323_8321_8329", + "num_8323_8323_8326" ~ "i64", + "num_8323_8323_8327" ~ "i64", + "([]t_1, []i64, []i16)" ~ "([]t_8323_8320_8322, []x_8323_8321_8323, []et_8323_8321_8329)" + ], + M.fromList [("t_1", (0, Unlifted, NoLoc))], + M.fromList [("t_8320", (5, TyVarFree NoLoc Lifted)), ("t_8321", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8323", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8328", (5, TyVarFree NoLoc Lifted)), ("t_8322_8329", (6, TyVarFree NoLoc Lifted)), ("t_8323_8320", (6, TyVarFree NoLoc Lifted)), ("t_8323_8321", (6, TyVarFree NoLoc Lifted)), ("t_8323_8322", (6, TyVarFree NoLoc Lifted)), ("t_8323_8323", (7, TyVarFree NoLoc Lifted)), ("t_8323_8324", (7, TyVarFree NoLoc Lifted)), ("t_8323_8325", (7, TyVarFree NoLoc Lifted)), ("t_8323_8326", (7, TyVarFree NoLoc Lifted)), ("arg_8323_8327", (8, TyVarFree NoLoc Lifted)), ("res_8323_8328", (8, TyVarFree NoLoc Lifted)), ("arg_8324_8321", (8, TyVarFree NoLoc Lifted)), ("res_8324_8322", (8, TyVarFree NoLoc Lifted)), ("arg_8324_8325", (8, TyVarFree NoLoc Lifted)), ("res_8324_8326", (8, TyVarFree NoLoc Lifted)), ("arg_8324_8329", (8, TyVarFree NoLoc Lifted)), ("res_8325_8320", (8, TyVarFree NoLoc Lifted)), ("arg_8325_8323", (8, TyVarFree NoLoc Lifted)), ("res_8325_8324", (8, TyVarFree NoLoc Lifted)), ("arg_8325_8327", (8, TyVarFree NoLoc Lifted)), ("res_8325_8328", (8, TyVarFree NoLoc Lifted)), ("arg_8326_8321", (8, TyVarFree NoLoc Lifted)), ("res_8326_8322", (8, TyVarFree NoLoc Lifted)), ("arg_8326_8325", (8, TyVarFree NoLoc Lifted)), ("res_8326_8326", (8, TyVarFree NoLoc Lifted)), ("a_8326_8329", (4, TyVarFree NoLoc Lifted)), ("b_8327_8320", (4, TyVarFree NoLoc Lifted)), ("a_8327_8321", (4, TyVarFree NoLoc Unlifted)), ("x_8327_8322", (4, TyVarFree NoLoc Unlifted)), ("t_8327_8329", (5, TyVarFree NoLoc Lifted)), ("a_8328_8320", (6, TyVarFree NoLoc Lifted)), ("b_8328_8321", (6, TyVarFree NoLoc Lifted)), ("a_8328_8322", (6, TyVarFree NoLoc Unlifted)), ("x_8328_8323", (6, TyVarFree NoLoc Unlifted)), ("t_8328_8324", (7, TyVarFree NoLoc Lifted)), ("t_8328_8325", (8, TyVarFree NoLoc Unlifted)), ("num_8328_8326", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8323", (8, TyVarFree NoLoc Unlifted)), ("num_8329_8324", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8321", (8, TyVarFree NoLoc Unlifted)), ("num_8321_8320_8322", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8329", (8, TyVarFree NoLoc Unlifted)), ("num_8321_8321_8320", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8322_8323", (7, TyVarFree NoLoc Lifted)), ("a_8321_8322_8324", (8, TyVarFree NoLoc Unlifted)), ("t_8321_8322_8325", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8322_8328", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8322_8329", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8323_8320", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8323_8321", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8328", (9, TyVarFree NoLoc Lifted)), ("t_8321_8323_8329", (10, TyVarFree NoLoc Unlifted)), ("num_8321_8324_8320", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8324_8325", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8324_8326", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8324_8327", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8324_8328", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8324_8329", (10, TyVarFree NoLoc Unlifted)), ("if_t_8321_8325_8322", (10, TyVarFree NoLoc SizeLifted)), ("t_8321_8325_8323", (11, TyVarFree NoLoc Lifted)), ("t_8321_8325_8324", (11, TyVarFree NoLoc Lifted)), ("t_8321_8325_8325", (11, TyVarFree NoLoc Lifted)), ("t_8321_8325_8326", (11, TyVarFree NoLoc Lifted)), ("t_8321_8325_8327", (13, TyVarFree NoLoc Lifted)), ("t_8321_8325_8328", (14, TyVarFree NoLoc Lifted)), ("t_8321_8325_8329", (14, TyVarFree NoLoc Lifted)), ("t_8321_8326_8320", (14, TyVarFree NoLoc Lifted)), ("t_8321_8326_8321", (14, TyVarFree NoLoc Lifted)), ("t_8321_8326_8322", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8323", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8324", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8325", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8326", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8327", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8328", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8326_8329", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8320", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8321", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8322", (15, TyVarFree NoLoc Unlifted)), ("num_8321_8327_8323", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8328_8328", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8328_8329", (15, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8329_8320", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8325", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8326", (15, TyVarFree NoLoc Unlifted)), ("num_8322_8320_8327", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8322_8322", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8322_8323", (15, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8322_8324", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8323_8329", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8320", (15, TyVarFree NoLoc Unlifted)), ("num_8322_8324_8321", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8325_8326", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8325_8327", (15, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8325_8328", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8327_8323", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8327_8324", (15, TyVarFree NoLoc Unlifted)), ("num_8322_8327_8325", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("a_8322_8329_8322", (12, TyVarFree NoLoc Unlifted)), ("b_8322_8329_8323", (12, TyVarFree NoLoc Unlifted)), ("x_8322_8329_8324", (12, TyVarFree NoLoc Unlifted)), ("t_8323_8320_8321", (13, TyVarFree NoLoc Lifted)), ("t_8323_8320_8322", (14, TyVarFree NoLoc Unlifted)), ("t_8323_8320_8323", (14, TyVarFree NoLoc Unlifted)), ("a_8323_8321_8322", (14, TyVarFree NoLoc Unlifted)), ("x_8323_8321_8323", (14, TyVarFree NoLoc Unlifted)), ("et_8323_8321_8324", (14, TyVarFree NoLoc Unlifted)), ("et_8323_8321_8329", (14, TyVarFree NoLoc Unlifted)), ("num_8323_8322_8320", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8322_8321", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8322_8326", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8322_8327", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8323_8326", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8323_8327", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8320_8321_8323_8323_8328", (13, TyVarFree NoLoc Lifted)), ("t_8323_8320_8323_8323_8323_8329", (14, TyVarFree NoLoc Unlifted)), ("t_8321_8323_8328_8323_8324_8320", (9, TyVarFree NoLoc Lifted)), ("t_8327_8329_8323_8324_8321", (5, TyVarFree NoLoc Lifted)), ("t_8321_8322_8323_8323_8324_8322", (7, TyVarFree NoLoc Lifted)), ("b_8328_8321_8323_8324_8323", (6, TyVarFree NoLoc Lifted)), ("a_8328_8320_8323_8324_8324", (6, TyVarFree NoLoc Lifted)), ("b_8327_8320_8323_8324_8325", (4, TyVarFree NoLoc Lifted)), ("a_8326_8329_8323_8324_8326", (4, TyVarFree NoLoc Lifted))] + ), + ( [ "t_8328_8325" ~ "[]t_8328_8325_8322_8321_8328", + "t_8322_8321_8320" ~ "[]t_8322_8321_8320_8322_8321_8329", + "index_8321_8327_8326" ~ "[]index_8321_8327_8326_8322_8322_8320", + "t_8329_8322" ~ "[][]t_8329_8322_8322_8322_8321", + "index_elem_8321_8327_8327" ~ "[]index_elem_8321_8327_8327_8322_8322_8322", + "index_8321_8326_8329" ~ "[]index_8321_8326_8329_8322_8322_8323", + "t_8321_8323_8324" ~ "[][]t_8321_8323_8324_8322_8322_8324", + "index_elem_8321_8327_8320" ~ "[]index_elem_8321_8327_8320_8322_8322_8325", + "b_8329_8324" ~ "[][]b_8329_8324_8322_8322_8326", + "a_8329_8323" ~ "[][]a_8329_8323_8322_8322_8327", + "b_8329_8326" ~ "[][]b_8329_8326_8322_8322_8328", + "a_8329_8325" ~ "[]a_8329_8325_8322_8322_8329", + "b_8329_8328" ~ "[]b_8329_8328_8322_8323_8320", + "a_8329_8327" ~ "[]a_8329_8327_8322_8323_8321", + "b_8321_8320_8320" ~ "[]b_8321_8320_8320_8322_8323_8322", + "a_8329_8329" ~ "[][]a_8329_8329_8322_8323_8323", + "b_8321_8320_8322" ~ "[][]b_8321_8320_8322_8322_8323_8324", + "a_8321_8320_8321" ~ "[][]a_8321_8320_8321_8322_8323_8325", + "t_8326_8321" ~ "[][]t_8326_8321_8322_8323_8326", + "t_8328_8326" ~ "[]t_8328_8326_8322_8323_8327", + "et_8328_8327" ~ "[]et_8328_8327_8322_8323_8328", + "t_8325_8324" ~ "[][]t_8325_8324_8322_8323_8329", + "t_8322_8324" ~ "[]t_8322_8324_8322_8324_8320", + "t_8322_8322" ~ "[]t_8322_8322_8322_8324_8321", + "t_8325_8322" ~ "[][]t_8325_8322_8322_8324_8322", + "t_8325_8325" ~ "[]t_8325_8325_8322_8324_8323", + "et_8325_8326" ~ "[]et_8325_8326_8322_8324_8324", + "t_8325_8323" ~ "[][]t_8325_8323_8322_8324_8325", + "t_8322_8323" ~ "[]t_8322_8323_8322_8324_8326", + "a_8322_8325" ~ "[]a_8322_8325_8322_8324_8327", + "b_8322_8328" ~ "[]b_8322_8328_8322_8324_8328", + "a_8323_8322" ~ "[]a_8323_8322_8322_8324_8329", + "a_8322_8327" ~ "[][]a_8322_8327_8322_8325_8320", + "t_8321_8323" ~ "[]t_8321_8323_8322_8325_8321", + "t_8321_8324" ~ "[]t_8321_8324_8322_8325_8322", + "i64" ~ "t_8321", + "i64" ~ "t_8321", + "t_8321" ~ "t_8320", + "i64" ~ "t_8320", + "t_8320" ~ "i64", + "[]t_1" ~ "[]t_8321_8320", + "(t_8321_8323, t_8321_8324)" ~ "([]t_8321_8320, []t_8321_8320)", + "i32 -> t_1 -> i32" ~ "i32 -> t_8321_8325 -> i32", + "i32" ~ "i32", + "t_8321_8324" ~ "[]t_8321_8325", + "(t_8322_8322, t_8322_8323, t_8322_8324)" ~ "([]t_8321_8325, []i64, []i16)", + "t_8321_8323" ~ "[]t_8322_8329", + "i32 -> t_1 -> i32" ~ "i32 -> t_8323_8324 -> i32", + "i32" ~ "i32", + "{xs: []t_8323_8324} -> ([]t_8323_8324, []i64, []i16)" ~ "a_8323_8322 -> x_8323_8323", + "[][]t_8322_8329" ~ "a_8322_8327", + "{as: []a_8323_8322} -> *[]x_8323_8323" ~ "a_8322_8327 -> b_8322_8328", + "b_8322_8328" ~ "a_8322_8325", + "{xs: [](a_8324_8325, b_8324_8326, c_8324_8327)} -> ([]a_8324_8325, []b_8324_8326, []c_8324_8327)" ~ "a_8322_8325 -> b_8322_8326", + "(t_8325_8322, t_8325_8323, t_8325_8324)" ~ "b_8322_8326", + "t_8322_8323" ~ "et_8325_8326", + "t_8325_8323" ~ "[]t_8325_8325", + "[]et_8325_8326" ~ "[]t_8325_8325", + "t_8326_8321" ~ "[]t_8325_8325", + "i64" ~ "t_8326_8324", + "i64" ~ "t_8326_8324", + "t_8326_8324" ~ "t_8326_8323", + "i64" ~ "t_8326_8323", + "t_8325_8322" ~ "[][]t_8327_8324", + "[]t_8327_8324" ~ "[]t_8327_8323", + "t_8322_8322" ~ "[]t_8327_8323", + "t_8326_8323" ~ "i64", + "[]t_8327_8323" ~ "[]t_8326_8322", + "t_8328_8325" ~ "[]t_8326_8322", + "t_8322_8324" ~ "et_8328_8327", + "t_8325_8324" ~ "[]t_8328_8326", + "[]et_8328_8327" ~ "[]t_8328_8326", + "t_8329_8322" ~ "[]t_8328_8326", + "t_8326_8321" ~ "a_8321_8320_8321", + "{a: [][]t_8321_8320_8323} -> [][]t_8321_8320_8323" ~ "a_8321_8320_8321 -> b_8321_8320_8322", + "b_8321_8320_8322" ~ "a_8329_8329", + "{xs: [][]t_8321_8320_8328} -> []t_8321_8320_8328" ~ "a_8329_8329 -> b_8321_8320_8320", + "t_8321_8321_8324 -> t_8321_8321_8324 -> t_8321_8321_8324" ~ "update_elem_8321_8321_8323 -> update_elem_8321_8321_8323 -> update_elem_8321_8321_8323", + "num_8321_8321_8325" ~ "update_elem_8321_8321_8323", + "b_8321_8320_8320" ~ "a_8329_8327", + "{xs: []update_elem_8321_8321_8323} -> *[]update_elem_8321_8321_8323" ~ "a_8329_8327 -> b_8329_8328", + "b_8329_8328" ~ "a_8329_8325", + "{xs: []t_8321_8322_8324} -> [][]t_8321_8322_8324" ~ "a_8329_8325 -> b_8329_8326", + "b_8329_8326" ~ "a_8329_8323", + "{a: [][]t_8321_8322_8329} -> [][]t_8321_8322_8329" ~ "a_8329_8323 -> b_8329_8324", + "t_8321_8323_8324" ~ "b_8329_8324", + "i64" ~ "t_8321_8323_8327", + "i64" ~ "t_8321_8323_8327", + "t_8321_8323_8327" ~ "t_8321_8323_8326", + "i64" ~ "t_8321_8323_8326", + "t_8321_8324_8327" ~ "t_8321_8324_8326", + "index_8321_8324_8328" ~ "index_elem_8321_8324_8329", + "t_8328_8325" ~ "[]index_elem_8321_8324_8329", + "t_8321_8325_8320" ~ "index_8321_8324_8328", + "num_8321_8325_8322" ~ "i32", + "i32 -> t_1 -> i32" ~ "i32 -> t_8321_8325_8321 -> i32", + "i32" ~ "i32", + "t_8321_8325_8320" ~ "t_8321_8325_8321", + "t_8321_8326_8321" ~ "i64", + "t_8321_8324_8326" ~ "t_8321_8326_8322", + "i64" ~ "t_8321_8326_8322", + "t_8321_8326_8327" ~ "t_8321_8326_8322", + "t_8321_8326_8328" ~ "t_8321_8326_8327", + "index_8321_8326_8329" ~ "index_elem_8321_8327_8320", + "t_8321_8323_8324" ~ "[]index_elem_8321_8327_8320", + "t_8321_8327_8321" ~ "t_8321_8326_8321", + "index_8321_8327_8322" ~ "index_elem_8321_8327_8323", + "index_8321_8326_8329" ~ "[]index_elem_8321_8327_8323", + "t_8321_8327_8324" ~ "index_8321_8327_8322", + "t_8321_8327_8325" ~ "t_8321_8326_8327", + "index_8321_8327_8326" ~ "index_elem_8321_8327_8327", + "t_8329_8322" ~ "[]index_elem_8321_8327_8327", + "t_8321_8327_8328" ~ "t_8321_8326_8321", + "index_8321_8327_8329" ~ "index_elem_8321_8328_8320", + "index_8321_8327_8326" ~ "[]index_elem_8321_8328_8320", + "index_8321_8327_8329" ~ "i16", + "t_8321_8328_8323" ~ "i64", + "i64" ~ "t_8321_8328_8325", + "t_8321_8326_8327" ~ "t_8321_8328_8325", + "t_8321_8328_8325" ~ "t_8321_8328_8324", + "t_8321_8328_8323" ~ "t_8321_8328_8324", + "t_8321_8329_8324" ~ "t_8321_8328_8324", + "t_8321_8324_8326" ~ "t_8321_8329_8326", + "t_8321_8329_8324" ~ "t_8321_8329_8326", + "t_8321_8329_8326" ~ "t_8321_8329_8325", + "t_8321_8327_8324" ~ "t_8321_8329_8325", + "t_8322_8320_8325" ~ "t_8321_8329_8325", + "t_8321_8323_8326" ~ "i64", + "{i: t_8321_8324_8326} -> t_8322_8320_8325" ~ "i64 -> a_8321_8323_8325", + "t_8322_8321_8320" ~ "[]a_8321_8323_8325", + "[]t_1" ~ "[]t_8322_8321_8321", + "t_8322_8321_8320" ~ "[]i64", + "t_8328_8325" ~ "[]t_8322_8321_8321" + ], + M.fromList [("t_1", (0, Unlifted, NoLoc))], + M.fromList [("t_8320", (3, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321", (3, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320", (4, TyVarFree NoLoc Unlifted)), ("t_8321_8323", (5, TyVarFree NoLoc Lifted)), ("t_8321_8324", (5, TyVarFree NoLoc Lifted)), ("t_8321_8325", (6, TyVarFree NoLoc Unlifted)), ("t_8322_8322", (7, TyVarFree NoLoc Lifted)), ("t_8322_8323", (7, TyVarFree NoLoc Lifted)), ("t_8322_8324", (7, TyVarFree NoLoc Lifted)), ("a_8322_8325", (8, TyVarFree NoLoc Lifted)), ("b_8322_8326", (8, TyVarFree NoLoc Lifted)), ("a_8322_8327", (8, TyVarFree NoLoc Lifted)), ("b_8322_8328", (8, TyVarFree NoLoc Lifted)), ("t_8322_8329", (8, TyVarFree NoLoc Unlifted)), ("a_8323_8322", (8, TyVarFree NoLoc Unlifted)), ("x_8323_8323", (8, TyVarFree NoLoc Unlifted)), ("t_8323_8324", (8, TyVarFree NoLoc Unlifted)), ("a_8324_8325", (8, TyVarFree NoLoc Unlifted)), ("b_8324_8326", (8, TyVarFree NoLoc Unlifted)), ("c_8324_8327", (8, TyVarFree NoLoc Unlifted)), ("t_8325_8322", (9, TyVarFree NoLoc Lifted)), ("t_8325_8323", (9, TyVarFree NoLoc Lifted)), ("t_8325_8324", (9, TyVarFree NoLoc Lifted)), ("t_8325_8325", (10, TyVarFree NoLoc Unlifted)), ("et_8325_8326", (10, TyVarFree NoLoc Unlifted)), ("t_8326_8321", (11, TyVarFree NoLoc Lifted)), ("t_8326_8322", (12, TyVarFree NoLoc Unlifted)), ("t_8326_8323", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8326_8324", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8327_8323", (12, TyVarFree NoLoc Unlifted)), ("t_8327_8324", (12, TyVarFree NoLoc Unlifted)), ("t_8328_8325", (13, TyVarFree NoLoc Lifted)), ("t_8328_8326", (14, TyVarFree NoLoc Unlifted)), ("et_8328_8327", (14, TyVarFree NoLoc Unlifted)), ("t_8329_8322", (15, TyVarFree NoLoc Lifted)), ("a_8329_8323", (16, TyVarFree NoLoc Lifted)), ("b_8329_8324", (16, TyVarFree NoLoc Lifted)), ("a_8329_8325", (16, TyVarFree NoLoc Lifted)), ("b_8329_8326", (16, TyVarFree NoLoc Lifted)), ("a_8329_8327", (16, TyVarFree NoLoc Lifted)), ("b_8329_8328", (16, TyVarFree NoLoc Lifted)), ("a_8329_8329", (16, TyVarFree NoLoc Lifted)), ("b_8321_8320_8320", (16, TyVarFree NoLoc Lifted)), ("a_8321_8320_8321", (16, TyVarFree NoLoc Lifted)), ("b_8321_8320_8322", (16, TyVarFree NoLoc Lifted)), ("t_8321_8320_8323", (16, TyVarFree NoLoc Unlifted)), ("t_8321_8320_8328", (16, TyVarFree NoLoc Unlifted)), ("update_elem_8321_8321_8323", (16, TyVarFree NoLoc Unlifted)), ("t_8321_8321_8324", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8321_8325", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8322_8324", (16, TyVarFree NoLoc Unlifted)), ("t_8321_8322_8329", (16, TyVarFree NoLoc Unlifted)), ("t_8321_8323_8324", (17, TyVarFree NoLoc Lifted)), ("a_8321_8323_8325", (18, TyVarFree NoLoc Unlifted)), ("t_8321_8323_8326", (18, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8327", (18, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8324_8326", (19, TyVarFree NoLoc Lifted)), ("t_8321_8324_8327", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8321_8324_8328", (20, TyVarFree NoLoc Unlifted)), ("index_elem_8321_8324_8329", (20, TyVarFree NoLoc Unlifted)), ("t_8321_8325_8320", (21, TyVarFree NoLoc Lifted)), ("t_8321_8325_8321", (22, TyVarFree NoLoc Unlifted)), ("num_8321_8325_8322", (22, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8321", (23, TyVarFree NoLoc Lifted)), ("t_8321_8326_8322", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8327", (25, TyVarFree NoLoc Lifted)), ("t_8321_8326_8328", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8321_8326_8329", (26, TyVarFree NoLoc Unlifted)), ("index_elem_8321_8327_8320", (26, TyVarFree NoLoc Unlifted)), ("t_8321_8327_8321", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8321_8327_8322", (26, TyVarFree NoLoc Unlifted)), ("index_elem_8321_8327_8323", (26, TyVarFree NoLoc Unlifted)), ("t_8321_8327_8324", (27, TyVarFree NoLoc Lifted)), ("t_8321_8327_8325", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8321_8327_8326", (28, TyVarFree NoLoc Unlifted)), ("index_elem_8321_8327_8327", (28, TyVarFree NoLoc Unlifted)), ("t_8321_8327_8328", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8321_8327_8329", (28, TyVarFree NoLoc Unlifted)), ("index_elem_8321_8328_8320", (28, TyVarFree NoLoc Unlifted)), ("t_8321_8328_8323", (29, TyVarFree NoLoc Lifted)), ("t_8321_8328_8324", (30, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8328_8325", (30, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8324", (31, TyVarFree NoLoc Lifted)), ("t_8321_8329_8325", (32, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8326", (32, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8325", (33, TyVarFree NoLoc Lifted)), ("t_8322_8321_8320", (19, TyVarFree NoLoc Lifted)), ("t_8322_8321_8321", (20, TyVarFree NoLoc Unlifted)), ("t_8328_8325_8322_8321_8328", (13, TyVarFree NoLoc Lifted)), ("t_8322_8321_8320_8322_8321_8329", (19, TyVarFree NoLoc Lifted)), ("index_8321_8327_8326_8322_8322_8320", (28, TyVarFree NoLoc Unlifted)), ("t_8329_8322_8322_8322_8321", (15, TyVarFree NoLoc Lifted)), ("index_elem_8321_8327_8327_8322_8322_8322", (28, TyVarFree NoLoc Unlifted)), ("index_8321_8326_8329_8322_8322_8323", (26, TyVarFree NoLoc Unlifted)), ("t_8321_8323_8324_8322_8322_8324", (17, TyVarFree NoLoc Lifted)), ("index_elem_8321_8327_8320_8322_8322_8325", (26, TyVarFree NoLoc Unlifted)), ("b_8329_8324_8322_8322_8326", (16, TyVarFree NoLoc Lifted)), ("a_8329_8323_8322_8322_8327", (16, TyVarFree NoLoc Lifted)), ("b_8329_8326_8322_8322_8328", (16, TyVarFree NoLoc Lifted)), ("a_8329_8325_8322_8322_8329", (16, TyVarFree NoLoc Lifted)), ("b_8329_8328_8322_8323_8320", (16, TyVarFree NoLoc Lifted)), ("a_8329_8327_8322_8323_8321", (16, TyVarFree NoLoc Lifted)), ("b_8321_8320_8320_8322_8323_8322", (16, TyVarFree NoLoc Lifted)), ("a_8329_8329_8322_8323_8323", (16, TyVarFree NoLoc Lifted)), ("b_8321_8320_8322_8322_8323_8324", (16, TyVarFree NoLoc Lifted)), ("a_8321_8320_8321_8322_8323_8325", (16, TyVarFree NoLoc Lifted)), ("t_8326_8321_8322_8323_8326", (11, TyVarFree NoLoc Lifted)), ("t_8328_8326_8322_8323_8327", (14, TyVarFree NoLoc Unlifted)), ("et_8328_8327_8322_8323_8328", (14, TyVarFree NoLoc Unlifted)), ("t_8325_8324_8322_8323_8329", (9, TyVarFree NoLoc Lifted)), ("t_8322_8324_8322_8324_8320", (7, TyVarFree NoLoc Lifted)), ("t_8322_8322_8322_8324_8321", (7, TyVarFree NoLoc Lifted)), ("t_8325_8322_8322_8324_8322", (9, TyVarFree NoLoc Lifted)), ("t_8325_8325_8322_8324_8323", (10, TyVarFree NoLoc Unlifted)), ("et_8325_8326_8322_8324_8324", (10, TyVarFree NoLoc Unlifted)), ("t_8325_8323_8322_8324_8325", (9, TyVarFree NoLoc Lifted)), ("t_8322_8323_8322_8324_8326", (7, TyVarFree NoLoc Lifted)), ("a_8322_8325_8322_8324_8327", (8, TyVarFree NoLoc Lifted)), ("b_8322_8328_8322_8324_8328", (8, TyVarFree NoLoc Lifted)), ("a_8323_8322_8322_8324_8329", (8, TyVarFree NoLoc Unlifted)), ("a_8322_8327_8322_8325_8320", (8, TyVarFree NoLoc Lifted)), ("t_8321_8323_8322_8325_8321", (5, TyVarFree NoLoc Lifted)), ("t_8321_8324_8322_8325_8322", (5, TyVarFree NoLoc Lifted))] + ), + ( [ "b_8325_8329" ~ "[]b_8325_8329_8328_8325", + "t_8326_8326" ~ "[]t_8326_8326_8328_8326", + "a_8325_8328" ~ "[]a_8325_8328_8328_8327", + "t_8326_8323" ~ "[]t_8326_8323_8328_8328", + "t_8325_8327" ~ "[]t_8325_8327_8328_8329", + "i64" ~ "t_8320", + "num_8321" ~ "t_8320", + "i32" ~ "t_8329", + "num_8321_8320" ~ "t_8329", + "t_8329" ~ "t_8328", + "num_8321_8325" ~ "t_8328", + "t_8328" ~ "t_8327", + "num_8322_8320" ~ "t_8327", + "bool" ~ "bool", + "num_8326" ~ "if_t_8322_8325", + "t_8327" ~ "if_t_8322_8325", + "t_8322_8326" ~ "if_t_8322_8325", + "i16" ~ "i16", + "t_8322_8329" ~ "i64", + "i64" ~ "t_8323_8320", + "t_8322_8329" ~ "t_8323_8320", + "t_8323_8325" ~ "t_8323_8320", + "i64" ~ "t_8323_8326", + "t_8322_8329" ~ "t_8323_8326", + "t_8324_8321" ~ "t_8323_8326", + "t_8323_8325" ~ "t_8324_8324", + "t_8322_8329" ~ "t_8324_8324", + "t_8324_8324" ~ "t_8324_8323", + "t_8324_8321" ~ "t_8324_8323", + "t_8324_8323" ~ "i64", + "[]t_1" ~ "[]t_8324_8322", + "t_8325_8327" ~ "[]t_8324_8322", + "i64" ~ "i64", + "t_8325_8327" ~ "t_8326_8323", + "t_8326_8326" ~ "t_8326_8323", + "t_8326_8327" ~ "t_8322_8326", + "t_8322_8326" ~ "t_8326_8329", + "num_8327_8320" ~ "t_8326_8329", + "i32 -> t_1 -> i32" ~ "i32 -> t_8326_8328 -> i32", + "t_8326_8329" ~ "i32", + "t_8326_8326" ~ "[]t_8326_8328", + "t_8326_8323" ~ "[]t_8326_8328", + "{xs: []t_8326_8320} -> []t_8326_8320" ~ "a_8325_8328 -> b_8325_8329", + "t_8326_8326" ~ "a_8325_8328", + "[]t_1" ~ "b_8325_8329" + ], + M.fromList [("t_1", (0, Unlifted, NoLoc))], + M.fromList [("t_8320", (5, TyVarFree NoLoc Unlifted)), ("num_8321", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8326", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8327", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8328", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8320", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8325", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8320", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8322_8325", (5, TyVarFree NoLoc SizeLifted)), ("t_8322_8326", (6, TyVarFree NoLoc Lifted)), ("t_8322_8329", (8, TyVarFree NoLoc Lifted)), ("t_8323_8320", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8325", (10, TyVarFree NoLoc Lifted)), ("t_8323_8326", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8321", (12, TyVarFree NoLoc Lifted)), ("t_8324_8322", (13, TyVarFree NoLoc Unlifted)), ("t_8324_8323", (13, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8324", (13, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8327", (14, TyVarFree NoLoc Lifted)), ("a_8325_8328", (15, TyVarFree NoLoc Lifted)), ("b_8325_8329", (15, TyVarFree NoLoc Lifted)), ("t_8326_8320", (15, TyVarFree NoLoc Unlifted)), ("t_8326_8323", (15, TyVarFree NoLoc Unlifted)), ("t_8326_8326", (15, TyVarFree NoLoc Lifted)), ("t_8326_8327", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64])), ("t_8326_8328", (15, TyVarFree NoLoc Unlifted)), ("t_8326_8329", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8327_8320", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("b_8325_8329_8328_8325", (15, TyVarFree NoLoc Lifted)), ("t_8326_8326_8328_8326", (15, TyVarFree NoLoc Lifted)), ("a_8325_8328_8328_8327", (15, TyVarFree NoLoc Lifted)), ("t_8326_8323_8328_8328", (15, TyVarFree NoLoc Unlifted)), ("t_8325_8327_8328_8329", (14, TyVarFree NoLoc Lifted))] + ), + ( [ "t_8324_8325_8320" ~ "[]t_8324_8325_8320_8324_8326_8326", + "t_8323_8327_8320" ~ "[]t_8323_8327_8320_8324_8326_8327", + "t_8323_8327_8322" ~ "[]t_8323_8327_8322_8324_8326_8328", + "t_8323_8327_8321" ~ "[]t_8323_8327_8321_8324_8326_8329", + "a_8323_8325_8320" ~ "[]a_8323_8325_8320_8324_8327_8320", + "t_8321" ~ "t_8323", + "num_8324" ~ "t_8323", + "i64" ~ "i64", + "t_8321" ~ "t_8329", + "i32" ~ "t_8329", + "bool" ~ "t_8322", + "bool" ~ "t_8322", + "t_8322_8320" ~ "t_8320", + "index_8322_8321" ~ "index_elem_8322_8322", + "[]u32" ~ "[]index_elem_8322_8322", + "t_8322_8323" ~ "index_8322_8321", + "t_8322_8324" ~ "t_8321", + "index_8322_8325" ~ "index_elem_8322_8326", + "[]u32" ~ "[]index_elem_8322_8326", + "t_8322_8327" ~ "index_8322_8325", + "t_8322_8323" ~ "t_8322_8328", + "t_8322_8327" ~ "t_8322_8328", + "t_8320" ~ "i32", + "t_8321" ~ "i32", + "u32" ~ "t_8323_8325", + "u32" ~ "t_8323_8325", + "t_8323_8325" ~ "u32", + "num_8323_8324" ~ "t_8323_8323", + "i32" ~ "t_8323_8323", + "t_8322_8323" ~ "t_8325_8320", + "t_8322_8327" ~ "t_8325_8320", + "t_8325_8320" ~ "u32", + "bool" ~ "bool", + "t_8323_8323" ~ "if_t_8325_8327", + "i32" ~ "if_t_8325_8327", + "t_8325_8329" ~ "num_8325_8328", + "t_8322" ~ "bool", + "if_t_8325_8327" ~ "if_t_8326_8320", + "num_8325_8328" ~ "if_t_8326_8320", + "t_8326_8321" ~ "i64", + "t_8326_8324" ~ "i32", + "t_8326_8324" ~ "t_8326_8326", + "num_8326_8327" ~ "t_8326_8326", + "(t_8326_8324, t_8326_8326)" ~ "(t_8320, t_8321)", + "t_8326_8324" ~ "t_8327_8324", + "num_8327_8325" ~ "t_8327_8324", + "(t_8326_8324, t_8327_8324)" ~ "(t_8320, t_8321)", + "if_t_8326_8320" ~ "t_8326_8325", + "if_t_8326_8320" ~ "t_8326_8325", + "t_8326_8325" ~ "i32", + "t_8328_8328" ~ "i32", + "t_8326_8324" ~ "t_8328_8329", + "t_8328_8328" ~ "t_8328_8329", + "(t_8326_8324, t_8328_8329)" ~ "(t_8320, t_8321)", + "t_8329_8326" ~ "if_t_8326_8320", + "t_8329_8328" ~ "num_8329_8327", + "t_8329_8328" ~ "t_8321_8320_8321", + "t_8328_8328" ~ "t_8321_8320_8321", + "t_8326_8324" ~ "t_8321_8320_8320", + "t_8321_8320_8321" ~ "t_8321_8320_8320", + "(t_8326_8324, t_8321_8320_8320)" ~ "(t_8320, t_8321)", + "if_t_8326_8320" ~ "t_8329_8329", + "t_8329_8326" ~ "t_8329_8329", + "t_8329_8328" ~ "t_8321_8321_8326", + "num_8321_8321_8327" ~ "t_8321_8321_8326", + "num_8329_8327" ~ "t_8321_8321_8326", + "t_8321_8322_8322" ~ "t_8329_8328", + "t_8321_8322_8322" ~ "t_8321_8322_8324", + "num_8321_8322_8325" ~ "t_8321_8322_8324", + "(t_8321_8323_8320, t_8321_8323_8321)" ~ "(num_8321_8322_8323, t_8321_8322_8324)", + "t_8321_8323_8321" ~ "t_8321_8323_8322", + "num_8321_8323_8323" ~ "t_8321_8323_8322", + "t_8321_8323_8320" ~ "t_8321_8324_8321", + "t_8321_8323_8321" ~ "t_8321_8324_8321", + "t_8321_8324_8321" ~ "t_8321_8324_8320", + "t_8328_8328" ~ "t_8321_8324_8320", + "t_8326_8324" ~ "t_8321_8323_8329", + "t_8321_8324_8320" ~ "t_8321_8323_8329", + "(t_8326_8324, t_8321_8323_8329)" ~ "(t_8320, t_8321)", + "if_t_8326_8320" ~ "t_8321_8323_8328", + "t_8329_8326" ~ "t_8321_8323_8328", + "t_8321_8323_8320" ~ "t_8321_8326_8320", + "t_8321_8323_8321" ~ "t_8321_8326_8320", + "t_8321_8323_8321" ~ "t_8321_8326_8325", + "num_8321_8326_8326" ~ "t_8321_8326_8325", + "t_8321_8323_8321" ~ "t_8321_8327_8321", + "num_8321_8327_8322" ~ "t_8321_8327_8321", + "bool" ~ "bool", + "(t_8321_8326_8320, t_8321_8326_8325)" ~ "if_t_8321_8327_8327", + "(t_8321_8323_8320, t_8321_8327_8321)" ~ "if_t_8321_8327_8327", + "(num_8321_8322_8323, t_8321_8322_8324)" ~ "if_t_8321_8327_8327", + "(t_8321_8327_8328, t_8321_8327_8329)" ~ "(t_8321_8323_8320, t_8321_8323_8321)", + "t_8321_8327_8328" ~ "t_8321_8328_8321", + "t_8328_8328" ~ "t_8321_8328_8321", + "t_8326_8324" ~ "t_8321_8328_8320", + "t_8321_8328_8321" ~ "t_8321_8328_8320", + "t_8321_8329_8320" ~ "t_8321_8328_8320", + "(t_8326_8324, t_8321_8329_8320)" ~ "(t_8320, t_8321)", + "t_8321_8329_8323" ~ "if_t_8326_8320", + "(t_8321_8329_8326, t_8321_8329_8327)" ~ "(num_8321_8329_8324, num_8321_8329_8325)", + "t_8321_8329_8327" ~ "t_8321_8329_8328", + "t_8321_8327_8328" ~ "t_8321_8329_8328", + "t_8321_8329_8327" ~ "t_8322_8320_8323", + "num_8322_8320_8324" ~ "t_8322_8320_8323", + "t_8321_8327_8328" ~ "i32", + "t_8322_8320_8323" ~ "i32", + "t_8322_8321_8323" ~ "i32", + "t_8321_8329_8326" ~ "t_8322_8321_8327", + "t_8322_8321_8323" ~ "t_8322_8321_8327", + "t_8322_8321_8327" ~ "t_8322_8321_8326", + "t_8328_8328" ~ "t_8322_8321_8326", + "t_8326_8324" ~ "t_8322_8321_8325", + "t_8322_8321_8326" ~ "t_8322_8321_8325", + "(t_8326_8324, t_8322_8321_8325)" ~ "(t_8320, t_8321)", + "if_t_8326_8320" ~ "t_8322_8321_8324", + "t_8321_8329_8323" ~ "t_8322_8321_8324", + "t_8321_8329_8326" ~ "t_8322_8323_8326", + "t_8322_8321_8323" ~ "t_8322_8323_8326", + "t_8321_8329_8327" ~ "t_8322_8324_8321", + "num_8322_8324_8322" ~ "t_8322_8324_8321", + "t_8321_8329_8327" ~ "t_8322_8324_8327", + "num_8322_8324_8328" ~ "t_8322_8324_8327", + "bool" ~ "bool", + "(t_8322_8323_8326, t_8322_8324_8321)" ~ "if_t_8322_8325_8323", + "(t_8321_8329_8326, t_8322_8324_8327)" ~ "if_t_8322_8325_8323", + "(num_8321_8329_8324, num_8321_8329_8325)" ~ "if_t_8322_8325_8323", + "(t_8322_8325_8324, t_8322_8325_8325)" ~ "(t_8321_8329_8326, t_8321_8329_8327)", + "t_8322_8325_8324" ~ "t_8322_8325_8328", + "t_8328_8328" ~ "t_8322_8325_8328", + "t_8326_8324" ~ "t_8322_8325_8327", + "t_8322_8325_8328" ~ "t_8322_8325_8327", + "t_8328_8328" ~ "i32", + "num_8322_8326_8327" ~ "i32", + "t_8322_8325_8327" ~ "t_8322_8325_8326", + "i32" ~ "t_8322_8325_8326", + "t_8322_8327_8326" ~ "t_8322_8325_8326", + "t_8326_8324" ~ "i32", + "t_8321_8329_8320" ~ "i32", + "i32" ~ "t_8322_8327_8327", + "t_8322_8327_8326" ~ "t_8322_8327_8327", + "t_8322_8328_8328" ~ "num_8322_8328_8327", + "bool" ~ "bool", + "(t_8322_8328_8326, num_8322_8328_8327)" ~ "if_t_8322_8329_8320", + "(t_8322_8328_8329, t_8322_8327_8326)" ~ "if_t_8322_8329_8320", + "(t_8322_8329_8321, t_8322_8329_8322)" ~ "if_t_8322_8329_8320", + "t_8326_8324" ~ "i32", + "t_8321_8329_8320" ~ "i32", + "t_8322_8327_8326" ~ "t_8322_8329_8328", + "num_8322_8329_8329" ~ "t_8322_8329_8328", + "i32" ~ "t_8322_8329_8323", + "t_8322_8329_8328" ~ "t_8322_8329_8323", + "t_8322_8327_8326" ~ "t_8323_8320_8328", + "num_8323_8320_8329" ~ "t_8323_8320_8328", + "t_8323_8321_8326" ~ "num_8323_8321_8325", + "t_8322_8327_8326" ~ "t_8323_8321_8327", + "num_8323_8321_8328" ~ "t_8323_8321_8327", + "t_8322_8327_8326" ~ "t_8323_8322_8324", + "num_8323_8322_8325" ~ "t_8323_8322_8324", + "bool" ~ "bool", + "(t_8323_8321_8324, num_8323_8321_8325)" ~ "if_t_8323_8323_8320", + "(t_8323_8322_8323, t_8323_8322_8324)" ~ "if_t_8323_8323_8320", + "(t_8323_8323_8321, t_8323_8323_8322)" ~ "if_t_8323_8323_8320", + "t_8323_8323_8324" ~ "t_8326_8324", + "index_8323_8323_8325" ~ "index_elem_8323_8323_8326", + "[]u32" ~ "[]index_elem_8323_8323_8326", + "t_8321_8329_8323" ~ "i32", + "num_8323_8323_8328" ~ "t_8323_8323_8327", + "u32" ~ "t_8323_8323_8327", + "index_8323_8323_8325" ~ "t_8323_8323_8323", + "t_8323_8323_8327" ~ "t_8323_8323_8323", + "t_8323_8324_8329" ~ "t_8323_8323_8323", + "i64" ~ "t_8323_8325_8323", + "num_8323_8325_8324" ~ "t_8323_8325_8323", + "t_8323_8325_8323" ~ "i64", + "{i: t_8326_8321} -> ({delta_node: t_8321_8329_8323, left: t_8322_8329_8321, right: t_8323_8323_8321, sfc_code: t_8323_8324_8329}, (t_8322_8329_8322, t_8326_8324), (t_8323_8323_8322, t_8326_8324))" ~ "i64 -> a_8323_8325_8322", + "[]a_8323_8325_8322" ~ "a_8323_8325_8320", + "{xs: [](a_8323_8326_8323, b_8323_8326_8324, c_8323_8326_8325)} -> ([]a_8323_8326_8323, []b_8323_8326_8324, []c_8323_8326_8325)" ~ "a_8323_8325_8320 -> b_8323_8325_8321", + "(t_8323_8327_8320, t_8323_8327_8321, t_8323_8327_8322)" ~ "b_8323_8325_8321", + "i64" ~ "t_8323_8327_8324", + "num_8323_8327_8325" ~ "t_8323_8327_8324", + "t_8323_8328_8321" ~ "num_8323_8328_8320", + "a_8323_8328_8328" ~ "ft_8323_8329_8320", + "a_8323_8328_8328 -> b_8323_8328_8329" ~ "a_8323_8328_8325 -> b_8323_8328_8326", + "i32 -> i64" ~ "b_8323_8328_8326 -> c_8323_8328_8327", + "{x: a_8323_8328_8325} -> c_8323_8328_8327" ~ "a_8323_8328_8323 -> x_8323_8328_8324", + "t_8323_8327_8321" ~ "[]a_8323_8328_8323", + "a_8324_8320_8324" ~ "ft_8324_8320_8326", + "a_8324_8320_8324 -> b_8324_8320_8325" ~ "a_8324_8320_8321 -> b_8324_8320_8322", + "i32 -> i64" ~ "b_8324_8320_8322 -> c_8324_8320_8323", + "{x: a_8324_8320_8321} -> c_8324_8320_8323" ~ "a_8323_8329_8329 -> x_8324_8320_8320", + "t_8323_8327_8322" ~ "[]a_8323_8329_8329", + "[]x_8323_8328_8324" ~ "[]t_8323_8328_8322", + "[]x_8324_8320_8320" ~ "[]t_8323_8328_8322", + "a_8324_8322_8322" ~ "ft_8324_8322_8324", + "a_8324_8322_8322 -> b_8324_8322_8323" ~ "a_8324_8322_8320 -> x_8324_8322_8321", + "t_8323_8327_8321" ~ "[]a_8324_8322_8320", + "a_8324_8323_8321" ~ "ft_8324_8323_8323", + "a_8324_8323_8321 -> b_8324_8323_8322" ~ "a_8324_8322_8329 -> x_8324_8323_8320", + "t_8323_8327_8322" ~ "[]a_8324_8322_8329", + "[]x_8324_8322_8321" ~ "[]t_8324_8321_8329", + "[]x_8324_8323_8320" ~ "[]t_8324_8321_8329", + "t_8323_8327_8324" ~ "i64", + "num_8323_8328_8320" ~ "t_8323_8327_8323", + "[]t_8323_8328_8322" ~ "[]i64", + "[]t_8324_8321_8329" ~ "[]t_8323_8327_8323", + "t_8324_8325_8320" ~ "[]t_8323_8327_8323", + "{delta_node: t_8324_8325_8324, left: t_8324_8325_8325, right: t_8324_8325_8326, sfc_code: t_8324_8325_8327} -> {parent: t_8324_8325_8328} -> {delta_node: t_8324_8325_8324, left: t_8324_8325_8325, parent: t_8324_8325_8328, right: t_8324_8325_8326, sfc_code: t_8324_8325_8327}" ~ "a_8324_8325_8321 -> b_8324_8325_8322 -> x_8324_8325_8323", + "t_8323_8327_8320" ~ "[]a_8324_8325_8321", + "t_8324_8325_8320" ~ "[]b_8324_8325_8322", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]x_8324_8325_8323" + ], + M.empty, + M.fromList [("t_8320", (3, TyVarFree NoLoc Lifted)), ("t_8321", (3, TyVarFree NoLoc Lifted)), ("t_8322", (4, TyVarPrim NoLoc [Bool])), ("t_8323", (4, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329", (4, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8321", (4, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8322", (4, TyVarFree NoLoc Unlifted)), ("t_8322_8323", (5, TyVarFree NoLoc Lifted)), ("t_8322_8324", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8325", (6, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8326", (6, TyVarFree NoLoc Unlifted)), ("t_8322_8327", (7, TyVarFree NoLoc Lifted)), ("t_8322_8328", (8, TyVarFree NoLoc Unlifted)), ("t_8323_8323", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8324", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8325", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64])), ("t_8325_8320", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64])), ("if_t_8325_8327", (8, TyVarFree NoLoc SizeLifted)), ("num_8325_8328", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8329", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8326_8320", (4, TyVarFree NoLoc SizeLifted)), ("t_8326_8321", (3, TyVarFree NoLoc Lifted)), ("t_8326_8324", (5, TyVarFree NoLoc Lifted)), ("t_8326_8325", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8326_8326", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8326_8327", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8327_8324", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8327_8325", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8328_8328", (7, TyVarFree NoLoc Lifted)), ("t_8328_8329", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8326", (9, TyVarFree NoLoc Lifted)), ("num_8329_8327", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8328", (10, TyVarFree NoLoc Lifted)), ("t_8329_8329", (10, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8320", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8321", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8321_8326", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8321_8327", (10, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8322_8322", (11, TyVarFree NoLoc Lifted)), ("num_8321_8322_8323", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8322_8324", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8322_8325", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8320", (12, TyVarFree NoLoc Lifted)), ("t_8321_8323_8321", (12, TyVarFree NoLoc Lifted)), ("t_8321_8323_8322", (12, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8323_8323", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8328", (12, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8329", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8324_8320", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8324_8321", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8320", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8325", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8326_8326", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8321", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8327_8322", (12, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8321_8327_8327", (12, TyVarFree NoLoc SizeLifted)), ("t_8321_8327_8328", (13, TyVarFree NoLoc Lifted)), ("t_8321_8327_8329", (13, TyVarFree NoLoc Lifted)), ("t_8321_8328_8320", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8328_8321", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8320", (15, TyVarFree NoLoc Lifted)), ("t_8321_8329_8323", (17, TyVarFree NoLoc Lifted)), ("num_8321_8329_8324", (18, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8329_8325", (18, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8326", (18, TyVarFree NoLoc Lifted)), ("t_8321_8329_8327", (18, TyVarFree NoLoc Lifted)), ("t_8321_8329_8328", (18, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8323", (18, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8320_8324", (18, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8321_8323", (19, TyVarFree NoLoc Lifted)), ("t_8322_8321_8324", (20, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8321_8325", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8321_8326", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8321_8327", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8323_8326", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8321", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8324_8322", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8327", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8324_8328", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8322_8325_8323", (20, TyVarFree NoLoc SizeLifted)), ("t_8322_8325_8324", (19, TyVarFree NoLoc Lifted)), ("t_8322_8325_8325", (19, TyVarFree NoLoc Lifted)), ("t_8322_8325_8326", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8325_8327", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8325_8328", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8326_8327", (20, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8327_8326", (21, TyVarFree NoLoc Lifted)), ("t_8322_8327_8327", (22, TyVarFree NoLoc Unlifted)), ("t_8322_8328_8326", (22, TyVarSum NoLoc (M.fromList [("leaf", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8322_8327_8326" 14782}) [])])]))), ("num_8322_8328_8327", (22, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8328_8328", (22, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8328_8329", (22, TyVarSum NoLoc (M.fromList [("inner", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8322_8327_8326" 14782}) [])])]))), ("if_t_8322_8329_8320", (22, TyVarFree NoLoc SizeLifted)), ("t_8322_8329_8321", (23, TyVarFree NoLoc Lifted)), ("t_8322_8329_8322", (23, TyVarFree NoLoc Lifted)), ("t_8322_8329_8323", (24, TyVarFree NoLoc Unlifted)), ("t_8322_8329_8328", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8329_8329", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8320_8328", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8320_8329", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8321_8324", (24, TyVarSum NoLoc (M.fromList [("leaf", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8323_8320_8328" 14819}) [])])]))), ("num_8323_8321_8325", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8321_8326", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8321_8327", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8321_8328", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8322_8323", (24, TyVarSum NoLoc (M.fromList [("inner", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8323_8321_8327" 14829}) [])])]))), ("t_8323_8322_8324", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8322_8325", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8323_8323_8320", (24, TyVarFree NoLoc SizeLifted)), ("t_8323_8323_8321", (25, TyVarFree NoLoc Lifted)), ("t_8323_8323_8322", (25, TyVarFree NoLoc Lifted)), ("t_8323_8323_8323", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64])), ("t_8323_8323_8324", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8323_8325", (26, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8323_8326", (26, TyVarFree NoLoc Unlifted)), ("t_8323_8323_8327", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8323_8328", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8324_8329", (27, TyVarFree NoLoc Lifted)), ("a_8323_8325_8320", (2, TyVarFree NoLoc Lifted)), ("b_8323_8325_8321", (2, TyVarFree NoLoc Lifted)), ("a_8323_8325_8322", (2, TyVarFree NoLoc Unlifted)), ("t_8323_8325_8323", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8325_8324", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("a_8323_8326_8323", (2, TyVarFree NoLoc Unlifted)), ("b_8323_8326_8324", (2, TyVarFree NoLoc Unlifted)), ("c_8323_8326_8325", (2, TyVarFree NoLoc Unlifted)), ("t_8323_8327_8320", (3, TyVarFree NoLoc Lifted)), ("t_8323_8327_8321", (3, TyVarFree NoLoc Lifted)), ("t_8323_8327_8322", (3, TyVarFree NoLoc Lifted)), ("t_8323_8327_8323", (4, TyVarFree NoLoc Unlifted)), ("t_8323_8327_8324", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8327_8325", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8328_8320", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8328_8321", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8328_8322", (4, TyVarFree NoLoc Unlifted)), ("a_8323_8328_8323", (4, TyVarFree NoLoc Unlifted)), ("x_8323_8328_8324", (4, TyVarFree NoLoc Unlifted)), ("a_8323_8328_8325", (4, TyVarFree NoLoc Lifted)), ("b_8323_8328_8326", (4, TyVarFree NoLoc Lifted)), ("c_8323_8328_8327", (4, TyVarFree NoLoc Lifted)), ("a_8323_8328_8328", (4, TyVarFree NoLoc Lifted)), ("b_8323_8328_8329", (4, TyVarFree NoLoc Lifted)), ("ft_8323_8329_8320", (4, TyVarRecord NoLoc (M.fromList [("0", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8323_8328_8329" 14921}) []))]))), ("a_8323_8329_8329", (4, TyVarFree NoLoc Unlifted)), ("x_8324_8320_8320", (4, TyVarFree NoLoc Unlifted)), ("a_8324_8320_8321", (4, TyVarFree NoLoc Lifted)), ("b_8324_8320_8322", (4, TyVarFree NoLoc Lifted)), ("c_8324_8320_8323", (4, TyVarFree NoLoc Lifted)), ("a_8324_8320_8324", (4, TyVarFree NoLoc Lifted)), ("b_8324_8320_8325", (4, TyVarFree NoLoc Lifted)), ("ft_8324_8320_8326", (4, TyVarRecord NoLoc (M.fromList [("0", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8324_8320_8325" 14941}) []))]))), ("t_8324_8321_8329", (4, TyVarFree NoLoc Unlifted)), ("a_8324_8322_8320", (4, TyVarFree NoLoc Unlifted)), ("x_8324_8322_8321", (4, TyVarFree NoLoc Unlifted)), ("a_8324_8322_8322", (4, TyVarFree NoLoc Lifted)), ("b_8324_8322_8323", (4, TyVarFree NoLoc Lifted)), ("ft_8324_8322_8324", (4, TyVarRecord NoLoc (M.fromList [("1", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8324_8322_8323" 14967}) []))]))), ("a_8324_8322_8329", (4, TyVarFree NoLoc Unlifted)), ("x_8324_8323_8320", (4, TyVarFree NoLoc Unlifted)), ("a_8324_8323_8321", (4, TyVarFree NoLoc Lifted)), ("b_8324_8323_8322", (4, TyVarFree NoLoc Lifted)), ("ft_8324_8323_8323", (4, TyVarRecord NoLoc (M.fromList [("1", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8324_8323_8322" 14979}) []))]))), ("t_8324_8325_8320", (5, TyVarFree NoLoc Lifted)), ("a_8324_8325_8321", (6, TyVarFree NoLoc Unlifted)), ("b_8324_8325_8322", (6, TyVarFree NoLoc Unlifted)), ("x_8324_8325_8323", (6, TyVarFree NoLoc Unlifted)), ("t_8324_8325_8324", (7, TyVarFree NoLoc Lifted)), ("t_8324_8325_8325", (7, TyVarFree NoLoc Lifted)), ("t_8324_8325_8326", (7, TyVarFree NoLoc Lifted)), ("t_8324_8325_8327", (7, TyVarFree NoLoc Lifted)), ("t_8324_8325_8328", (8, TyVarFree NoLoc Lifted)), ("t_8324_8325_8320_8324_8326_8326", (5, TyVarFree NoLoc Lifted)), ("t_8323_8327_8320_8324_8326_8327", (3, TyVarFree NoLoc Lifted)), ("t_8323_8327_8322_8324_8326_8328", (3, TyVarFree NoLoc Lifted)), ("t_8323_8327_8321_8324_8326_8329", (3, TyVarFree NoLoc Lifted)), ("a_8323_8325_8320_8324_8327_8320", (2, TyVarFree NoLoc Lifted))] + ), + ( [ "t_8326_8325_8325" ~ "[]t_8326_8325_8325_8326_8325_8327", + "t_8322_8320_8326" ~ "[]t_8322_8320_8326_8326_8325_8328", + "t_8326_8324_8326" ~ "[]t_8326_8324_8326_8326_8325_8329", + "if_t_8326_8324_8325" ~ "[]if_t_8326_8324_8325_8326_8326_8320", + "t_8324_8324_8326" ~ "[]t_8324_8324_8326_8326_8326_8321", + "t_8326_8324_8320" ~ "[]t_8326_8324_8320_8326_8326_8322", + "t_8324_8326_8325" ~ "[]t_8324_8326_8325_8326_8326_8323", + "t_8324_8326_8326" ~ "[]t_8324_8326_8326_8326_8326_8324", + "t_8326_8324_8321" ~ "[]t_8326_8324_8321_8326_8326_8325", + "t_8324_8326_8323" ~ "[]t_8324_8326_8323_8326_8326_8326", + "t_8327_8322" ~ "[]t_8327_8322_8326_8326_8327", + "t_8321_8325_8326" ~ "[]t_8321_8325_8326_8326_8326_8328", + "b_8321_8328_8321" ~ "[]b_8321_8328_8321_8326_8326_8329", + "a_8321_8328_8320" ~ "[]a_8321_8328_8320_8326_8327_8320", + "t_8321_8327_8329" ~ "[]t_8321_8327_8329_8326_8327_8321", + "t_8321_8327_8328" ~ "[]t_8321_8327_8328_8326_8327_8322", + "a_8321_8325_8327" ~ "[]a_8321_8325_8327_8326_8327_8323", + "t_8321_8321_8329" ~ "[]t_8321_8321_8329_8326_8327_8324", + "b_8327_8324" ~ "[]b_8327_8324_8326_8327_8325", + "a_8327_8323" ~ "[]a_8327_8323_8326_8327_8326", + "b_8327_8326" ~ "[]b_8327_8326_8326_8327_8327", + "a_8327_8325" ~ "[]a_8327_8325_8326_8327_8328", + "t_8321" ~ "num_8320", + "index_8322" ~ "index_elem_8323", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8323", + "index_8322" ~ "t_8325", + "t_8326" ~ "kt_8324", + "t_8327" ~ "t_8321_8320", + "num_8321_8321" ~ "t_8321_8320", + "t_8328" ~ "t_8321_8326", + "num_8321_8327" ~ "t_8321_8326", + "t_8321_8320" ~ "t_8329", + "t_8321_8326" ~ "t_8329", + "i32" ~ "t_8329", + "t_8322_8328" ~ "t_8322_8326", + "i32" ~ "t_8322_8329", + "num_8323_8320" ~ "t_8322_8329", + "t_8323_8326" ~ "t_8322_8326", + "t_8323_8327" ~ "t_8323_8325", + "index_8323_8328" ~ "index_elem_8323_8329", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8323_8329", + "index_8323_8328" ~ "t_8324_8321", + "t_8322_8329" ~ "kt_8324_8320", + "match_t_8324_8322" ~ "t_8322_8329", + "t_8324_8325" ~ "t_8324_8327", + "kt_8324_8326" ~ "t_8322_8326", + "t_8324_8325" ~ "t_8325_8321", + "match_t_8324_8322" ~ "t_8327", + "kt_8325_8320" ~ "t_8328", + "t_8325_8326" ~ "i32", + "t_8324_8325" ~ "t_8325_8328", + "kt_8325_8327" ~ "t_8322_8326", + "t_8324_8325" ~ "t_8326_8322", + "match_t_8324_8322" ~ "t_8327", + "kt_8326_8321" ~ "t_8328", + "t_8326_8327" ~ "i32", + "{n: t_8324_8325} -> {left: t_8325_8326, right: t_8326_8327}" ~ "a_8324_8323 -> x_8324_8324", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]a_8324_8323", + "t_8327_8322" ~ "[]x_8324_8324", + "t_8327_8329" ~ "t_8328_8322", + "num_8328_8323" ~ "t_8328_8322", + "bool" ~ "bool", + "t_8328_8320" ~ "t_8329_8320", + "num_8329_8321" ~ "t_8329_8320", + "bool" ~ "bool", + "i32" ~ "t_8328_8321", + "i32" ~ "t_8328_8321", + "{left: t_8327_8329, right: t_8328_8320} -> t_8328_8321" ~ "a_8327_8327 -> x_8327_8328", + "t_8327_8322" ~ "a_8327_8325", + "{as: []a_8327_8327} -> *[]x_8327_8328" ~ "a_8327_8325 -> b_8327_8326", + "t_8321_8320_8329 -> t_8321_8320_8329 -> t_8321_8320_8329" ~ "a_8321_8320_8328 -> a_8321_8320_8328 -> a_8321_8320_8328", + "num_8321_8321_8320" ~ "a_8321_8320_8328", + "b_8327_8326" ~ "a_8327_8323", + "{as: []a_8321_8320_8328} -> *[]a_8321_8320_8328" ~ "a_8327_8323 -> b_8327_8324", + "t_8321_8321_8329" ~ "b_8327_8324", + "t_8321_8321_8329" ~ "[]t_8321_8322_8321", + "t_8321_8322_8321" ~ "i32", + "i64" ~ "t_8321_8322_8320", + "num_8321_8322_8326" ~ "t_8321_8322_8320", + "t_8321_8323_8321" ~ "t_8321_8322_8320", + "t_8321_8323_8326" ~ "t_8321_8323_8324", + "num_8321_8323_8325" ~ "t_8321_8323_8324", + "t_8321_8324_8323" ~ "num_8321_8324_8322", + "num_8321_8324_8322" ~ "i64", + "t_8321_8321_8329" ~ "[]t_8321_8324_8321", + "t_8321_8323_8326 -> t_8321_8323_8324" ~ "a_8321_8323_8322 -> x_8321_8323_8323", + "[]t_8321_8324_8321" ~ "[]a_8321_8323_8322", + "t_8321_8325_8323" ~ "num_8321_8325_8322", + "[]x_8321_8323_8323" ~ "[]update_elem_8321_8325_8325", + "num_8321_8325_8324" ~ "update_elem_8321_8325_8325", + "t_8321_8325_8326" ~ "[]x_8321_8323_8323", + "t_8321_8326_8321" ~ "i32", + "t_8321_8325_8326" ~ "[]t_8321_8326_8325", + "{x: t_8321_8326_8321} -> (i64, num_8321_8326_8324)" ~ "a_8321_8325_8329 -> x_8321_8326_8320", + "[]t_8321_8326_8325" ~ "[]a_8321_8325_8329", + "[]x_8321_8326_8320" ~ "a_8321_8325_8327", + "{xs: [](a_8321_8327_8322, b_8321_8327_8323)} -> ([]a_8321_8327_8322, []b_8321_8327_8323)" ~ "a_8321_8325_8327 -> b_8321_8325_8328", + "(t_8321_8327_8328, t_8321_8327_8329)" ~ "b_8321_8325_8328", + "t_8321_8328_8323 -> t_8321_8328_8323 -> t_8321_8328_8323" ~ "a_8321_8328_8322 -> a_8321_8328_8322 -> a_8321_8328_8322", + "num_8321_8328_8324" ~ "a_8321_8328_8322", + "t_8321_8323_8321" ~ "i64", + "t_8321_8327_8328" ~ "[]i64", + "t_8321_8327_8329" ~ "[]a_8321_8328_8322", + "num_8321_8329_8327" ~ "i32", + "t_8321_8329_8326 -> t_8321_8329_8326 -> t_8321_8329_8326" ~ "a_8321_8329_8325 -> a_8321_8329_8325 -> a_8321_8329_8325", + "num_8321_8329_8327" ~ "a_8321_8329_8325", + "[]a_8321_8328_8322" ~ "a_8321_8328_8320", + "{as: []a_8321_8329_8325} -> *[]a_8321_8329_8325" ~ "a_8321_8328_8320 -> b_8321_8328_8321", + "t_8322_8320_8326" ~ "b_8321_8328_8321", + "t_8322_8321_8320" ~ "i64", + "t_8322_8321_8324" ~ "i32", + "t_8322_8321_8325" ~ "t_8322_8321_8321", + "index_8322_8321_8326" ~ "index_elem_8322_8321_8327", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8322_8321_8327", + "t_8322_8321_8328" ~ "index_8322_8321_8326", + "t_8322_8321_8329" ~ "t_8322_8321_8321", + "index_8322_8322_8320" ~ "index_elem_8322_8322_8321", + "t_8327_8322" ~ "[]index_elem_8322_8322_8321", + "t_8322_8322_8322" ~ "index_8322_8322_8320", + "t_8322_8322_8323" ~ "t_8322_8321_8321", + "index_8322_8322_8324" ~ "index_elem_8322_8322_8325", + "t_8321_8325_8326" ~ "[]index_elem_8322_8322_8325", + "t_8322_8322_8326" ~ "index_8322_8322_8324", + "t_8322_8321_8328" ~ "t_8322_8323_8320", + "kt_8322_8322_8329" ~ "t_8322_8322_8328", + "num_8322_8323_8321" ~ "t_8322_8322_8328", + "t_8326" ~ "t_8322_8323_8326", + "num_8322_8323_8327" ~ "t_8322_8323_8326", + "t_8322_8322_8328" ~ "t_8322_8322_8327", + "t_8322_8323_8326" ~ "t_8322_8322_8327", + "t_8322_8324_8326" ~ "t_8322_8322_8327", + "t_8322_8322_8322" ~ "t_8322_8324_8328", + "t_8322_8324_8329" ~ "kt_8322_8324_8327", + "t_8322_8324_8329" ~ "t_8322_8325_8321", + "num_8322_8325_8322" ~ "t_8322_8325_8321", + "t_8322_8322_8326" ~ "t_8322_8325_8327", + "t_8322_8321_8324" ~ "t_8322_8325_8327", + "bool" ~ "t_8322_8325_8320", + "bool" ~ "t_8322_8325_8320", + "t_8322_8326_8326" ~ "t_8322_8325_8320", + "t_8322_8321_8324" ~ "t_8322_8326_8327", + "num_8322_8326_8328" ~ "t_8322_8326_8327", + "t_8322_8322_8322" ~ "t_8322_8327_8326", + "t_8322_8326_8326" ~ "bool", + "t_8322_8324_8329" ~ "if_t_8322_8327_8327", + "kt_8322_8327_8325" ~ "if_t_8322_8327_8327", + "t_8322_8324_8326" ~ "t_8322_8327_8324", + "if_t_8322_8327_8327" ~ "t_8322_8327_8324", + "bool" ~ "bool", + "num_8322_8327_8323" ~ "if_t_8322_8328_8322", + "t_8322_8327_8324" ~ "if_t_8322_8328_8322", + "t_8322_8328_8323" ~ "if_t_8322_8328_8322", + "t_8322_8321_8328" ~ "t_8322_8328_8325", + "t_8322_8321_8328" ~ "t_8322_8328_8327", + "t_8322_8326_8326" ~ "bool", + "kt_8322_8328_8324" ~ "if_t_8322_8328_8328", + "kt_8322_8328_8326" ~ "if_t_8322_8328_8328", + "t_8322_8328_8329" ~ "if_t_8322_8328_8328", + "t_8322_8321_8324" ~ "t_8322_8329_8321", + "num_8322_8329_8322" ~ "t_8322_8329_8321", + "t_8322_8329_8328" ~ "t_8322_8328_8329", + "t_8323_8320_8320" ~ "t_8322_8328_8329", + "match_t_8323_8320_8321" ~ "bool", + "bool" ~ "t_8322_8329_8320", + "match_t_8323_8320_8321" ~ "t_8322_8329_8320", + "t_8323_8320_8326" ~ "t_8322_8329_8320", + "t_8323_8320_8329" ~ "t_8323_8320_8327", + "t_8323_8321_8321" ~ "t_8323_8320_8327", + "t_8323_8320_8328" ~ "t_8323_8321_8320", + "match_t_8323_8321_8322" ~ "t_8323_8320_8328", + "i32" ~ "match_t_8323_8321_8322", + "t_8322_8328_8329" ~ "t_8323_8320_8327", + "t_8323_8321_8325" ~ "i32", + "index_8323_8321_8326" ~ "index_elem_8323_8321_8327", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]index_elem_8323_8321_8327", + "t_8323_8321_8328" ~ "index_8323_8321_8326", + "t_8323_8321_8328" ~ "t_8323_8322_8320", + "t_8323_8321_8328" ~ "t_8323_8322_8322", + "kt_8323_8321_8329" ~ "f32", + "kt_8323_8322_8321" ~ "{x: f32, y: f32, z: f32}", + "t_8323_8321_8328" ~ "ft_8323_8322_8327", + "(num_8323_8322_8328, num_8323_8322_8329, num_8323_8323_8320)" ~ "(f32, f32, f32)", + "(num_8323_8323_8324, num_8323_8323_8325, num_8323_8323_8326)" ~ "(f32, f32, f32)", + "t_8323_8320_8326" ~ "bool", + "t_8323_8321_8328" ~ "if_t_8323_8323_8329", + "{mass: float_8323_8323_8323, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "if_t_8323_8323_8329", + "t_8323_8324_8320" ~ "if_t_8323_8323_8329", + "t_8322_8321_8324" ~ "t_8323_8324_8321", + "num_8323_8324_8322" ~ "t_8323_8324_8321", + "t_8323_8324_8328" ~ "num_8323_8324_8327", + "t_8323_8324_8329" ~ "t_8322_8321_8321", + "t_8323_8325_8321" ~ "num_8323_8325_8320", + "t_8323_8325_8322" ~ "num_8323_8325_8320", + "(t_8323_8325_8323, t_8323_8325_8324)" ~ "(t_8323_8325_8322, t_8323_8324_8329)", + "t_8323_8325_8327" ~ "num_8323_8325_8326", + "t_8323_8325_8323" ~ "t_8323_8325_8325", + "num_8323_8325_8326" ~ "t_8323_8325_8325", + "t_8323_8326_8322" ~ "t_8323_8325_8324", + "index_8323_8326_8323" ~ "index_elem_8323_8326_8324", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8323_8326_8324", + "index_8323_8326_8323" ~ "t_8323_8326_8326", + "t_8323_8326_8327" ~ "kt_8323_8326_8325", + "t_8323_8327_8320" ~ "num_8323_8326_8329", + "t_8323_8326_8327" ~ "t_8323_8326_8328", + "num_8323_8326_8329" ~ "t_8323_8326_8328", + "t_8323_8327_8326" ~ "t_8323_8326_8327", + "index_8323_8327_8327" ~ "index_elem_8323_8327_8328", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8323_8327_8328", + "index_8323_8327_8327" ~ "t_8323_8328_8320", + "kt_8323_8327_8329" ~ "t_8323_8320_8327", + "t_8323_8328_8323" ~ "i32", + "t_8323_8328_8324" ~ "t_8323_8326_8327", + "index_8323_8328_8325" ~ "index_elem_8323_8328_8326", + "t_8327_8322" ~ "[]index_elem_8323_8328_8326", + "t_8323_8328_8327" ~ "index_8323_8328_8325", + "t_8323_8328_8323" ~ "t_8323_8328_8328", + "t_8323_8325_8324" ~ "t_8323_8328_8328", + "t_8323_8329_8323" ~ "bool", + "t_8323_8328_8327" ~ "t_8323_8329_8325", + "t_8323_8328_8327" ~ "t_8323_8329_8327", + "t_8323_8329_8323" ~ "bool", + "kt_8323_8329_8324" ~ "if_t_8323_8329_8328", + "kt_8323_8329_8326" ~ "if_t_8323_8329_8328", + "t_8323_8329_8329" ~ "if_t_8323_8329_8328", + "t_8323_8329_8329" ~ "t_8324_8320_8320", + "num_8324_8320_8321" ~ "t_8324_8320_8320", + "t_8324_8320_8326" ~ "t_8323_8326_8327", + "index_8324_8320_8327" ~ "index_elem_8324_8320_8328", + "t_8321_8325_8326" ~ "[]index_elem_8324_8320_8328", + "t_8324_8320_8329" ~ "index_8324_8320_8327", + "t_8323_8328_8327" ~ "t_8324_8321_8325", + "kt_8324_8321_8324" ~ "t_8324_8321_8323", + "num_8324_8321_8326" ~ "t_8324_8321_8323", + "bool" ~ "bool", + "num_8324_8321_8322" ~ "t_8324_8321_8321", + "i32" ~ "t_8324_8321_8321", + "t_8324_8320_8329" ~ "t_8324_8321_8320", + "t_8324_8321_8321" ~ "t_8324_8321_8320", + "t_8323_8329_8323" ~ "bool", + "t_8324_8320_8329" ~ "if_t_8324_8323_8321", + "t_8324_8321_8320" ~ "if_t_8324_8323_8321", + "bool" ~ "bool", + "(if_t_8324_8323_8321, t_8323_8326_8327)" ~ "if_t_8324_8323_8322", + "(t_8323_8325_8323, t_8323_8326_8327)" ~ "if_t_8324_8323_8322", + "bool" ~ "bool", + "(num_8323_8327_8325, t_8323_8326_8327)" ~ "if_t_8324_8323_8323", + "if_t_8324_8323_8322" ~ "if_t_8324_8323_8323", + "(t_8323_8325_8322, t_8323_8324_8329)" ~ "if_t_8324_8323_8323", + "(t_8324_8323_8324, t_8324_8323_8325)" ~ "(t_8323_8325_8323, t_8323_8325_8324)", + "bool" ~ "bool", + "num_8323_8324_8327" ~ "if_t_8324_8323_8326", + "t_8324_8323_8324" ~ "if_t_8324_8323_8326", + "t_8324_8323_8327" ~ "if_t_8324_8323_8326", + "t_8324_8323_8329" ~ "num_8324_8323_8328", + "i32" ~ "num_8324_8323_8328", + "num_8324_8324_8321" ~ "i64", + "i32" ~ "t_8324_8324_8320", + "t_8324_8324_8326" ~ "[]t_8324_8324_8320", + "t_8322_8321_8324" ~ "t_8324_8324_8327", + "num_8324_8324_8328" ~ "t_8324_8324_8327", + "t_8322_8328_8329" ~ "t_8323_8320_8327", + "bool" ~ "bool", + "num_8324_8325_8323" ~ "if_t_8324_8325_8326", + "i32" ~ "if_t_8324_8325_8326", + "t_8324_8325_8327" ~ "if_t_8324_8325_8326", + "i32" ~ "num_8324_8325_8328", + "num_8324_8326_8320" ~ "et_8324_8325_8329", + "num_8324_8326_8321" ~ "et_8324_8325_8329", + "num_8324_8326_8322" ~ "et_8324_8325_8329", + "t_8324_8326_8323" ~ "[]et_8324_8325_8329", + "i32" ~ "num_8324_8326_8324", + "(t_8324_8326_8325, t_8324_8326_8326, t_8324_8326_8327, t_8324_8326_8328, t_8324_8326_8329)" ~ "(t_8324_8324_8326, t_8324_8326_8323, i32, t_8324_8325_8327, i32)", + "t_8324_8327_8322" ~ "num_8324_8327_8321", + "index_8324_8327_8323" ~ "index_elem_8324_8327_8324", + "t_8324_8326_8326" ~ "[]index_elem_8324_8327_8324", + "index_8324_8327_8323" ~ "t_8324_8327_8320", + "num_8324_8327_8325" ~ "t_8324_8327_8320", + "t_8324_8328_8321" ~ "t_8324_8326_8329", + "index_8324_8328_8322" ~ "index_elem_8324_8328_8323", + "t_8324_8326_8326" ~ "[]index_elem_8324_8328_8323", + "index_8324_8328_8322" ~ "t_8324_8328_8320", + "num_8324_8328_8324" ~ "t_8324_8328_8320", + "t_8324_8329_8320" ~ "t_8324_8326_8328", + "index_8324_8329_8321" ~ "index_elem_8324_8329_8322", + "t_8327_8322" ~ "[]index_elem_8324_8329_8322", + "index_8324_8329_8321" ~ "t_8324_8329_8324", + "kt_8324_8329_8323" ~ "t_8324_8328_8329", + "num_8324_8329_8325" ~ "t_8324_8328_8329", + "t_8325_8320_8320" ~ "t_8324_8326_8328", + "index_8325_8320_8321" ~ "index_elem_8325_8320_8322", + "t_8321_8325_8326" ~ "[]index_elem_8325_8320_8322", + "t_8325_8320_8323" ~ "index_8325_8320_8321", + "t_8325_8320_8324" ~ "t_8324_8326_8327", + "t_8324_8326_8325" ~ "[]update_elem_8325_8320_8325", + "t_8325_8320_8323" ~ "update_elem_8325_8320_8325", + "t_8324_8326_8327" ~ "t_8325_8320_8326", + "num_8325_8320_8327" ~ "t_8325_8320_8326", + "t_8325_8321_8322" ~ "t_8325_8320_8326", + "t_8325_8321_8323" ~ "t_8324_8326_8329", + "t_8324_8326_8326" ~ "[]update_elem_8325_8321_8325", + "num_8325_8321_8324" ~ "update_elem_8325_8321_8325", + "t_8325_8321_8326" ~ "t_8324_8326_8328", + "index_8325_8321_8327" ~ "index_elem_8325_8321_8328", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8325_8321_8328", + "index_8325_8321_8327" ~ "t_8325_8322_8320", + "kt_8325_8321_8329" ~ "t_8323_8320_8327", + "t_8325_8322_8323" ~ "i32", + "t_8324_8326_8329" ~ "t_8325_8322_8324", + "num_8325_8322_8325" ~ "t_8325_8322_8324", + "t_8325_8323_8320" ~ "t_8325_8322_8324", + "t_8325_8323_8321" ~ "t_8325_8323_8320", + "t_8324_8326_8326" ~ "[]update_elem_8325_8323_8323", + "num_8325_8323_8322" ~ "update_elem_8325_8323_8323", + "bool" ~ "bool", + "(t_8324_8326_8325, t_8324_8326_8326, t_8325_8321_8322, t_8324_8326_8328, t_8324_8326_8329)" ~ "if_t_8325_8323_8324", + "(t_8324_8326_8325, t_8324_8326_8326, t_8324_8326_8327, t_8325_8322_8323, t_8325_8323_8320)" ~ "if_t_8325_8323_8324", + "t_8325_8323_8326" ~ "t_8324_8326_8329", + "index_8325_8323_8327" ~ "index_elem_8325_8323_8328", + "t_8324_8326_8326" ~ "[]index_elem_8325_8323_8328", + "index_8325_8323_8327" ~ "t_8325_8323_8325", + "num_8325_8323_8329" ~ "t_8325_8323_8325", + "t_8325_8324_8325" ~ "t_8324_8326_8328", + "index_8325_8324_8326" ~ "index_elem_8325_8324_8327", + "t_8327_8322" ~ "[]index_elem_8325_8324_8327", + "index_8325_8324_8326" ~ "t_8325_8324_8329", + "kt_8325_8324_8328" ~ "t_8325_8324_8324", + "num_8325_8325_8320" ~ "t_8325_8324_8324", + "t_8325_8325_8325" ~ "t_8324_8326_8328", + "index_8325_8325_8326" ~ "index_elem_8325_8325_8327", + "t_8321_8325_8326" ~ "[]index_elem_8325_8325_8327", + "t_8325_8325_8328" ~ "index_8325_8325_8326", + "t_8325_8325_8329" ~ "t_8324_8326_8327", + "t_8325_8326_8323" ~ "t_8324_8326_8328", + "index_8325_8326_8324" ~ "index_elem_8325_8326_8325", + "t_8327_8322" ~ "[]index_elem_8325_8326_8325", + "index_8325_8326_8324" ~ "t_8325_8326_8327", + "kt_8325_8326_8326" ~ "t_8325_8326_8322", + "num_8325_8326_8328" ~ "t_8325_8326_8322", + "bool" ~ "bool", + "i32" ~ "t_8325_8326_8321", + "num_8325_8327_8325" ~ "t_8325_8326_8321", + "t_8325_8325_8328" ~ "t_8325_8326_8320", + "t_8325_8326_8321" ~ "t_8325_8326_8320", + "t_8324_8326_8325" ~ "[]update_elem_8325_8328_8324", + "t_8325_8326_8320" ~ "update_elem_8325_8328_8324", + "t_8324_8326_8327" ~ "t_8325_8328_8325", + "num_8325_8328_8326" ~ "t_8325_8328_8325", + "t_8325_8329_8321" ~ "t_8325_8328_8325", + "t_8325_8329_8322" ~ "t_8324_8326_8329", + "t_8324_8326_8326" ~ "[]update_elem_8325_8329_8324", + "num_8325_8329_8323" ~ "update_elem_8325_8329_8324", + "t_8325_8329_8325" ~ "t_8324_8326_8328", + "index_8325_8329_8326" ~ "index_elem_8325_8329_8327", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8325_8329_8327", + "index_8325_8329_8326" ~ "t_8325_8329_8329", + "kt_8325_8329_8328" ~ "t_8323_8320_8327", + "t_8326_8320_8322" ~ "i32", + "t_8324_8326_8329" ~ "t_8326_8320_8323", + "num_8326_8320_8324" ~ "t_8326_8320_8323", + "t_8326_8320_8329" ~ "t_8326_8320_8323", + "t_8326_8321_8320" ~ "t_8326_8320_8329", + "t_8324_8326_8326" ~ "[]update_elem_8326_8321_8322", + "num_8326_8321_8321" ~ "update_elem_8326_8321_8322", + "bool" ~ "bool", + "(t_8324_8326_8325, t_8324_8326_8326, t_8325_8329_8321, t_8324_8326_8328, t_8324_8326_8329)" ~ "if_t_8326_8321_8323", + "(t_8324_8326_8325, t_8324_8326_8326, t_8324_8326_8327, t_8326_8320_8322, t_8326_8320_8329)" ~ "if_t_8326_8321_8323", + "t_8326_8321_8324" ~ "t_8324_8326_8328", + "index_8326_8321_8325" ~ "index_elem_8326_8321_8326", + "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}" ~ "[]index_elem_8326_8321_8326", + "index_8326_8321_8325" ~ "t_8326_8321_8328", + "t_8326_8321_8329" ~ "kt_8326_8321_8327", + "t_8324_8326_8329" ~ "t_8326_8322_8320", + "num_8326_8322_8321" ~ "t_8326_8322_8320", + "t_8326_8322_8326" ~ "t_8326_8322_8320", + "t_8326_8322_8327" ~ "t_8326_8322_8326", + "t_8326_8322_8329" ~ "t_8326_8322_8326", + "index_8326_8323_8320" ~ "index_elem_8326_8323_8321", + "t_8324_8326_8326" ~ "[]index_elem_8326_8323_8321", + "num_8326_8323_8322" ~ "i32", + "index_8326_8323_8320" ~ "t_8326_8322_8328", + "num_8326_8323_8322" ~ "t_8326_8322_8328", + "t_8324_8326_8326" ~ "[]update_elem_8326_8323_8327", + "t_8326_8322_8328" ~ "update_elem_8326_8323_8327", + "bool" ~ "bool", + "if_t_8326_8321_8323" ~ "if_t_8326_8323_8328", + "(t_8324_8326_8325, t_8324_8326_8326, t_8324_8326_8327, t_8326_8321_8329, t_8326_8322_8326)" ~ "if_t_8326_8323_8328", + "bool" ~ "bool", + "if_t_8325_8323_8324" ~ "if_t_8326_8323_8329", + "if_t_8326_8323_8328" ~ "if_t_8326_8323_8329", + "(t_8324_8324_8326, t_8324_8326_8323, i32, t_8324_8325_8327, i32)" ~ "if_t_8326_8323_8329", + "(t_8326_8324_8320, t_8326_8324_8321, t_8326_8324_8322, t_8326_8324_8323, t_8326_8324_8324)" ~ "(t_8324_8326_8325, t_8324_8326_8326, t_8324_8326_8327, t_8324_8326_8328, t_8324_8326_8329)", + "t_8323_8320_8326" ~ "bool", + "t_8324_8324_8326" ~ "if_t_8326_8324_8325", + "t_8326_8324_8320" ~ "if_t_8326_8324_8325", + "t_8326_8324_8326" ~ "if_t_8326_8324_8325", + "t_8321_8323_8321" ~ "i64", + "{i: t_8322_8321_8320} -> {rp: t_8322_8321_8321} -> {body: t_8323_8324_8320, children: t_8326_8324_8326, is_leaf: t_8323_8320_8326, parent: t_8324_8323_8327, tree_level: t_8322_8328_8323}" ~ "a_8322_8320_8327 -> b_8322_8320_8328 -> x_8322_8320_8329", + "[]i64" ~ "[]a_8322_8320_8327", + "t_8322_8320_8326" ~ "[]b_8322_8320_8328", + "t_8326_8325_8325" ~ "[]x_8322_8320_8329", + "[]{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "t_8326_8325_8325" + ], + M.empty, + M.fromList [("num_8320", (3, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321", (3, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322", (3, TyVarFree NoLoc Unlifted)), ("index_elem_8323", (3, TyVarFree NoLoc Unlifted)), ("kt_8324", (3, TyVarFree NoLoc Lifted)), ("t_8325", (3, TyVarRecord NoLoc (M.fromList [("delta_node", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8324" 15206}) []))]))), ("t_8326", (4, TyVarFree NoLoc Lifted)), ("t_8327", (6, TyVarFree NoLoc Lifted)), ("t_8328", (7, TyVarFree NoLoc Lifted)), ("t_8329", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8321", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8327", (8, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8326", (6, TyVarFree NoLoc Lifted)), ("t_8322_8327", (7, TyVarFree NoLoc Lifted)), ("t_8322_8328", (7, TyVarSum NoLoc (M.fromList [("leaf", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8322_8327" 15232}) [])])]))), ("t_8322_8329", (7, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8320", (7, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8325", (7, TyVarFree NoLoc Lifted)), ("t_8323_8326", (7, TyVarSum NoLoc (M.fromList [("inner", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8323_8325" 15241}) [])])]))), ("t_8323_8327", (7, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8328", (7, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8329", (7, TyVarFree NoLoc Unlifted)), ("kt_8324_8320", (7, TyVarFree NoLoc Lifted)), ("t_8324_8321", (7, TyVarRecord NoLoc (M.fromList [("delta_node", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8324_8320" 15247}) []))]))), ("match_t_8324_8322", (7, TyVarFree NoLoc SizeLifted)), ("a_8324_8323", (5, TyVarFree NoLoc Unlifted)), ("x_8324_8324", (5, TyVarFree NoLoc Unlifted)), ("t_8324_8325", (6, TyVarFree NoLoc Lifted)), ("kt_8324_8326", (7, TyVarFree NoLoc Lifted)), ("t_8324_8327", (7, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8324_8326" 15255}) []))]))), ("kt_8325_8320", (7, TyVarFree NoLoc Lifted)), ("t_8325_8321", (7, TyVarRecord NoLoc (M.fromList [("delta_node", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8325_8320" 15260}) []))]))), ("t_8325_8326", (8, TyVarFree NoLoc Lifted)), ("kt_8325_8327", (9, TyVarFree NoLoc Lifted)), ("t_8325_8328", (9, TyVarRecord NoLoc (M.fromList [("right", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8325_8327" 15268}) []))]))), ("kt_8326_8321", (9, TyVarFree NoLoc Lifted)), ("t_8326_8322", (9, TyVarRecord NoLoc (M.fromList [("delta_node", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8326_8321" 15273}) []))]))), ("t_8326_8327", (10, TyVarFree NoLoc Lifted)), ("t_8327_8322", (6, TyVarFree NoLoc Lifted)), ("a_8327_8323", (7, TyVarFree NoLoc Lifted)), ("b_8327_8324", (7, TyVarFree NoLoc Lifted)), ("a_8327_8325", (7, TyVarFree NoLoc Lifted)), ("b_8327_8326", (7, TyVarFree NoLoc Lifted)), ("a_8327_8327", (7, TyVarFree NoLoc Unlifted)), ("x_8327_8328", (7, TyVarFree NoLoc Unlifted)), ("t_8327_8329", (8, TyVarFree NoLoc Lifted)), ("t_8328_8320", (8, TyVarFree NoLoc Lifted)), ("t_8328_8321", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8328_8322", (9, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8328_8323", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8320", (9, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8329_8321", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("a_8321_8320_8328", (7, TyVarFree NoLoc Unlifted)), ("t_8321_8320_8329", (7, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8321_8320", (7, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8321_8329", (8, TyVarFree NoLoc Lifted)), ("t_8321_8322_8320", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8322_8321", (9, TyVarFree NoLoc Unlifted)), ("num_8321_8322_8326", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8321", (10, TyVarFree NoLoc Lifted)), ("a_8321_8323_8322", (11, TyVarFree NoLoc Unlifted)), ("x_8321_8323_8323", (11, TyVarFree NoLoc Unlifted)), ("t_8321_8323_8324", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8323_8325", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8326", (11, TyVarFree NoLoc Lifted)), ("t_8321_8324_8321", (11, TyVarFree NoLoc Unlifted)), ("num_8321_8324_8322", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8324_8323", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8325_8322", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8325_8323", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("num_8321_8325_8324", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("update_elem_8321_8325_8325", (11, TyVarFree NoLoc Unlifted)), ("t_8321_8325_8326", (12, TyVarFree NoLoc Lifted)), ("a_8321_8325_8327", (13, TyVarFree NoLoc Lifted)), ("b_8321_8325_8328", (13, TyVarFree NoLoc Lifted)), ("a_8321_8325_8329", (13, TyVarFree NoLoc Unlifted)), ("x_8321_8326_8320", (13, TyVarFree NoLoc Unlifted)), ("t_8321_8326_8321", (14, TyVarFree NoLoc Lifted)), ("num_8321_8326_8324", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8326_8325", (13, TyVarFree NoLoc Unlifted)), ("a_8321_8327_8322", (13, TyVarFree NoLoc Unlifted)), ("b_8321_8327_8323", (13, TyVarFree NoLoc Unlifted)), ("t_8321_8327_8328", (14, TyVarFree NoLoc Lifted)), ("t_8321_8327_8329", (14, TyVarFree NoLoc Lifted)), ("a_8321_8328_8320", (15, TyVarFree NoLoc Lifted)), ("b_8321_8328_8321", (15, TyVarFree NoLoc Lifted)), ("a_8321_8328_8322", (15, TyVarFree NoLoc Unlifted)), ("t_8321_8328_8323", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8328_8324", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("a_8321_8329_8325", (15, TyVarFree NoLoc Unlifted)), ("t_8321_8329_8326", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8329_8327", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8326", (16, TyVarFree NoLoc Lifted)), ("a_8322_8320_8327", (17, TyVarFree NoLoc Unlifted)), ("b_8322_8320_8328", (17, TyVarFree NoLoc Unlifted)), ("x_8322_8320_8329", (17, TyVarFree NoLoc Unlifted)), ("t_8322_8321_8320", (18, TyVarFree NoLoc Lifted)), ("t_8322_8321_8321", (19, TyVarFree NoLoc Lifted)), ("t_8322_8321_8324", (21, TyVarFree NoLoc Lifted)), ("t_8322_8321_8325", (22, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8321_8326", (22, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8321_8327", (22, TyVarFree NoLoc Unlifted)), ("t_8322_8321_8328", (23, TyVarFree NoLoc Lifted)), ("t_8322_8321_8329", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8322_8320", (24, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8322_8321", (24, TyVarFree NoLoc Unlifted)), ("t_8322_8322_8322", (25, TyVarFree NoLoc Lifted)), ("t_8322_8322_8323", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8322_8324", (26, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8322_8325", (26, TyVarFree NoLoc Unlifted)), ("t_8322_8322_8326", (27, TyVarFree NoLoc Lifted)), ("t_8322_8322_8327", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8322_8328", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8322_8322_8329", (28, TyVarFree NoLoc Lifted)), ("t_8322_8323_8320", (28, TyVarRecord NoLoc (M.fromList [("delta_node", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8322_8329" 15494}) []))]))), ("num_8322_8323_8321", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8323_8326", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8323_8327", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8326", (29, TyVarFree NoLoc Lifted)), ("kt_8322_8324_8327", (30, TyVarFree NoLoc Lifted)), ("t_8322_8324_8328", (30, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8324_8327" 15515}) []))]))), ("t_8322_8324_8329", (31, TyVarFree NoLoc Lifted)), ("t_8322_8325_8320", (32, TyVarPrim NoLoc [Bool])), ("t_8322_8325_8321", (32, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8325_8322", (32, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8325_8327", (32, TyVarFree NoLoc Unlifted)), ("t_8322_8326_8326", (33, TyVarFree NoLoc Lifted)), ("t_8322_8326_8327", (34, TyVarFree NoLoc Unlifted)), ("num_8322_8326_8328", (34, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8327_8323", (34, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8327_8324", (34, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8322_8327_8325", (34, TyVarFree NoLoc Lifted)), ("t_8322_8327_8326", (34, TyVarRecord NoLoc (M.fromList [("right", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8327_8325" 15547}) []))]))), ("if_t_8322_8327_8327", (34, TyVarFree NoLoc SizeLifted)), ("if_t_8322_8328_8322", (34, TyVarFree NoLoc SizeLifted)), ("t_8322_8328_8323", (35, TyVarFree NoLoc Lifted)), ("kt_8322_8328_8324", (36, TyVarFree NoLoc Lifted)), ("t_8322_8328_8325", (36, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8328_8324" 15557}) []))]))), ("kt_8322_8328_8326", (36, TyVarFree NoLoc Lifted)), ("t_8322_8328_8327", (36, TyVarRecord NoLoc (M.fromList [("right", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8328_8326" 15559}) []))]))), ("if_t_8322_8328_8328", (36, TyVarFree NoLoc SizeLifted)), ("t_8322_8328_8329", (37, TyVarFree NoLoc Lifted)), ("t_8322_8329_8320", (38, TyVarPrim NoLoc [Bool])), ("t_8322_8329_8321", (38, TyVarFree NoLoc Unlifted)), ("num_8322_8329_8322", (38, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8329_8327", (38, TyVarFree NoLoc Lifted)), ("t_8322_8329_8328", (38, TyVarSum NoLoc (M.fromList [("leaf", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8322_8329_8327" 15571}) [])])]))), ("t_8322_8329_8329", (38, TyVarFree NoLoc Lifted)), ("t_8323_8320_8320", (38, TyVarSum NoLoc (M.fromList [("inner", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8322_8329_8329" 15573}) [])])]))), ("match_t_8323_8320_8321", (38, TyVarFree NoLoc SizeLifted)), ("t_8323_8320_8326", (39, TyVarFree NoLoc Lifted)), ("t_8323_8320_8327", (41, TyVarFree NoLoc Lifted)), ("t_8323_8320_8328", (42, TyVarFree NoLoc Lifted)), ("t_8323_8320_8329", (42, TyVarSum NoLoc (M.fromList [("inner", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8323_8320_8328" 15583}) [])])]))), ("t_8323_8321_8320", (42, TyVarFree NoLoc Lifted)), ("t_8323_8321_8321", (42, TyVarSum NoLoc (M.fromList [("leaf", [Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8323_8321_8320" 15585}) [])])]))), ("match_t_8323_8321_8322", (42, TyVarFree NoLoc SizeLifted)), ("t_8323_8321_8325", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8321_8326", (40, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8321_8327", (40, TyVarFree NoLoc Unlifted)), ("t_8323_8321_8328", (41, TyVarFree NoLoc Lifted)), ("kt_8323_8321_8329", (42, TyVarFree NoLoc Lifted)), ("t_8323_8322_8320", (42, TyVarRecord NoLoc (M.fromList [("mass", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8323_8321_8329" 15596}) []))]))), ("kt_8323_8322_8321", (42, TyVarFree NoLoc Lifted)), ("t_8323_8322_8322", (42, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8323_8322_8321" 15598}) []))]))), ("ft_8323_8322_8327", (42, TyVarRecord NoLoc (M.fromList [("position", Scalar (Record (M.fromList [("x", Scalar (Prim (FloatType Float32))), ("y", Scalar (Prim (FloatType Float32))), ("z", Scalar (Prim (FloatType Float32)))])))]))), ("num_8323_8322_8328", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8322_8329", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8323_8320", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("float_8323_8323_8323", (40, TyVarPrim NoLoc [FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8323_8324", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8323_8325", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8323_8326", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8323_8323_8329", (40, TyVarFree NoLoc SizeLifted)), ("t_8323_8324_8320", (41, TyVarFree NoLoc Lifted)), ("t_8323_8324_8321", (42, TyVarFree NoLoc Unlifted)), ("num_8323_8324_8322", (42, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8324_8327", (42, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8324_8328", (42, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8324_8329", (43, TyVarFree NoLoc Lifted)), ("num_8323_8325_8320", (44, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8325_8321", (44, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8325_8322", (45, TyVarFree NoLoc Lifted)), ("t_8323_8325_8323", (46, TyVarFree NoLoc Lifted)), ("t_8323_8325_8324", (46, TyVarFree NoLoc Lifted)), ("t_8323_8325_8325", (46, TyVarFree NoLoc Unlifted)), ("num_8323_8325_8326", (46, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8325_8327", (46, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8326_8322", (46, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8326_8323", (46, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8326_8324", (46, TyVarFree NoLoc Unlifted)), ("kt_8323_8326_8325", (46, TyVarFree NoLoc Lifted)), ("t_8323_8326_8326", (46, TyVarRecord NoLoc (M.fromList [("parent", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8323_8326_8325" 15648}) []))]))), ("t_8323_8326_8327", (47, TyVarFree NoLoc Lifted)), ("t_8323_8326_8328", (48, TyVarFree NoLoc Unlifted)), ("num_8323_8326_8329", (48, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8327_8320", (48, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8327_8325", (48, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8327_8326", (48, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8327_8327", (48, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8327_8328", (48, TyVarFree NoLoc Unlifted)), ("kt_8323_8327_8329", (48, TyVarFree NoLoc Lifted)), ("t_8323_8328_8320", (48, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8323_8327_8329" 15664}) []))]))), ("t_8323_8328_8323", (49, TyVarFree NoLoc Lifted)), ("t_8323_8328_8324", (50, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8328_8325", (50, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8328_8326", (50, TyVarFree NoLoc Unlifted)), ("t_8323_8328_8327", (51, TyVarFree NoLoc Lifted)), ("t_8323_8328_8328", (52, TyVarFree NoLoc Unlifted)), ("t_8323_8329_8323", (53, TyVarFree NoLoc Lifted)), ("kt_8323_8329_8324", (54, TyVarFree NoLoc Lifted)), ("t_8323_8329_8325", (54, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8323_8329_8324" 15681}) []))]))), ("kt_8323_8329_8326", (54, TyVarFree NoLoc Lifted)), ("t_8323_8329_8327", (54, TyVarRecord NoLoc (M.fromList [("right", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8323_8329_8326" 15683}) []))]))), ("if_t_8323_8329_8328", (54, TyVarFree NoLoc SizeLifted)), ("t_8323_8329_8329", (55, TyVarFree NoLoc Lifted)), ("t_8324_8320_8320", (56, TyVarFree NoLoc Unlifted)), ("num_8324_8320_8321", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8320_8326", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8324_8320_8327", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8324_8320_8328", (56, TyVarFree NoLoc Unlifted)), ("t_8324_8320_8329", (57, TyVarFree NoLoc Lifted)), ("t_8324_8321_8320", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8321_8321", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324_8321_8322", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8321_8323", (58, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8324_8321_8324", (58, TyVarFree NoLoc Lifted)), ("t_8324_8321_8325", (58, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8324_8321_8324" 15702}) []))]))), ("num_8324_8321_8326", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8324_8323_8321", (58, TyVarFree NoLoc SizeLifted)), ("if_t_8324_8323_8322", (56, TyVarFree NoLoc SizeLifted)), ("if_t_8324_8323_8323", (48, TyVarFree NoLoc SizeLifted)), ("t_8324_8323_8324", (47, TyVarFree NoLoc Lifted)), ("t_8324_8323_8325", (47, TyVarFree NoLoc Lifted)), ("if_t_8324_8323_8326", (42, TyVarFree NoLoc SizeLifted)), ("t_8324_8323_8327", (43, TyVarFree NoLoc Lifted)), ("num_8324_8323_8328", (44, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8323_8329", (44, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8324_8320", (46, TyVarFree NoLoc Unlifted)), ("num_8324_8324_8321", (46, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8324_8326", (47, TyVarFree NoLoc Lifted)), ("t_8324_8324_8327", (48, TyVarFree NoLoc Unlifted)), ("num_8324_8324_8328", (48, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324_8325_8323", (48, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8324_8325_8326", (48, TyVarFree NoLoc SizeLifted)), ("t_8324_8325_8327", (49, TyVarFree NoLoc Lifted)), ("num_8324_8325_8328", (50, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("et_8324_8325_8329", (52, TyVarFree NoLoc Unlifted)), ("num_8324_8326_8320", (52, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324_8326_8321", (52, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324_8326_8322", (52, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8326_8323", (53, TyVarFree NoLoc Lifted)), ("num_8324_8326_8324", (54, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8326_8325", (56, TyVarFree NoLoc Lifted)), ("t_8324_8326_8326", (56, TyVarFree NoLoc Lifted)), ("t_8324_8326_8327", (56, TyVarFree NoLoc Lifted)), ("t_8324_8326_8328", (56, TyVarFree NoLoc Lifted)), ("t_8324_8326_8329", (56, TyVarFree NoLoc Lifted)), ("t_8324_8327_8320", (56, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324_8327_8321", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8327_8322", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8324_8327_8323", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8324_8327_8324", (56, TyVarFree NoLoc Unlifted)), ("num_8324_8327_8325", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8328_8320", (56, TyVarFree NoLoc Unlifted)), ("t_8324_8328_8321", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8324_8328_8322", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8324_8328_8323", (56, TyVarFree NoLoc Unlifted)), ("num_8324_8328_8324", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8328_8329", (56, TyVarFree NoLoc Unlifted)), ("t_8324_8329_8320", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8324_8329_8321", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8324_8329_8322", (56, TyVarFree NoLoc Unlifted)), ("kt_8324_8329_8323", (56, TyVarFree NoLoc Lifted)), ("t_8324_8329_8324", (56, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8324_8329_8323" 15791}) []))]))), ("num_8324_8329_8325", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8320_8320", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8325_8320_8321", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8325_8320_8322", (56, TyVarFree NoLoc Unlifted)), ("t_8325_8320_8323", (57, TyVarFree NoLoc Lifted)), ("t_8325_8320_8324", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("update_elem_8325_8320_8325", (58, TyVarFree NoLoc Unlifted)), ("t_8325_8320_8326", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8325_8320_8327", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8321_8322", (59, TyVarFree NoLoc Lifted)), ("t_8325_8321_8323", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("num_8325_8321_8324", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("update_elem_8325_8321_8325", (60, TyVarFree NoLoc Unlifted)), ("t_8325_8321_8326", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8325_8321_8327", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8325_8321_8328", (56, TyVarFree NoLoc Unlifted)), ("kt_8325_8321_8329", (56, TyVarFree NoLoc Lifted)), ("t_8325_8322_8320", (56, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8325_8321_8329" 15820}) []))]))), ("t_8325_8322_8323", (57, TyVarFree NoLoc Lifted)), ("t_8325_8322_8324", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8325_8322_8325", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8323_8320", (59, TyVarFree NoLoc Lifted)), ("t_8325_8323_8321", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("num_8325_8323_8322", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("update_elem_8325_8323_8323", (60, TyVarFree NoLoc Unlifted)), ("if_t_8325_8323_8324", (56, TyVarFree NoLoc SizeLifted)), ("t_8325_8323_8325", (56, TyVarFree NoLoc Unlifted)), ("t_8325_8323_8326", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8325_8323_8327", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8325_8323_8328", (56, TyVarFree NoLoc Unlifted)), ("num_8325_8323_8329", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8324_8324", (56, TyVarFree NoLoc Unlifted)), ("t_8325_8324_8325", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8325_8324_8326", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8325_8324_8327", (56, TyVarFree NoLoc Unlifted)), ("kt_8325_8324_8328", (56, TyVarFree NoLoc Lifted)), ("t_8325_8324_8329", (56, TyVarRecord NoLoc (M.fromList [("right", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8325_8324_8328" 15852}) []))]))), ("num_8325_8325_8320", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8325_8325", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8325_8325_8326", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8325_8325_8327", (56, TyVarFree NoLoc Unlifted)), ("t_8325_8325_8328", (57, TyVarFree NoLoc Lifted)), ("t_8325_8325_8329", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("t_8325_8326_8320", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8326_8321", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8326_8322", (58, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8326_8323", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8325_8326_8324", (58, TyVarFree NoLoc Unlifted)), ("index_elem_8325_8326_8325", (58, TyVarFree NoLoc Unlifted)), ("kt_8325_8326_8326", (58, TyVarFree NoLoc Lifted)), ("t_8325_8326_8327", (58, TyVarRecord NoLoc (M.fromList [("left", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8325_8326_8326" 15871}) []))]))), ("num_8325_8326_8328", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8325_8327_8325", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("update_elem_8325_8328_8324", (58, TyVarFree NoLoc Unlifted)), ("t_8325_8328_8325", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8325_8328_8326", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8329_8321", (59, TyVarFree NoLoc Lifted)), ("t_8325_8329_8322", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("num_8325_8329_8323", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("update_elem_8325_8329_8324", (60, TyVarFree NoLoc Unlifted)), ("t_8325_8329_8325", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8325_8329_8326", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8325_8329_8327", (56, TyVarFree NoLoc Unlifted)), ("kt_8325_8329_8328", (56, TyVarFree NoLoc Lifted)), ("t_8325_8329_8329", (56, TyVarRecord NoLoc (M.fromList [("right", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8325_8329_8328" 15909}) []))]))), ("t_8326_8320_8322", (57, TyVarFree NoLoc Lifted)), ("t_8326_8320_8323", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8326_8320_8324", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8326_8320_8329", (59, TyVarFree NoLoc Lifted)), ("t_8326_8321_8320", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("num_8326_8321_8321", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("update_elem_8326_8321_8322", (60, TyVarFree NoLoc Unlifted)), ("if_t_8326_8321_8323", (56, TyVarFree NoLoc SizeLifted)), ("t_8326_8321_8324", (56, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8326_8321_8325", (56, TyVarFree NoLoc Unlifted)), ("index_elem_8326_8321_8326", (56, TyVarFree NoLoc Unlifted)), ("kt_8326_8321_8327", (56, TyVarFree NoLoc Lifted)), ("t_8326_8321_8328", (56, TyVarRecord NoLoc (M.fromList [("parent", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8326_8321_8327" 15931}) []))]))), ("t_8326_8321_8329", (57, TyVarFree NoLoc Lifted)), ("t_8326_8322_8320", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8326_8322_8321", (58, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8326_8322_8326", (59, TyVarFree NoLoc Lifted)), ("t_8326_8322_8327", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("t_8326_8322_8328", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8326_8322_8329", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8326_8323_8320", (60, TyVarFree NoLoc Unlifted)), ("index_elem_8326_8323_8321", (60, TyVarFree NoLoc Unlifted)), ("num_8326_8323_8322", (60, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("update_elem_8326_8323_8327", (60, TyVarFree NoLoc Unlifted)), ("if_t_8326_8323_8328", (56, TyVarFree NoLoc SizeLifted)), ("if_t_8326_8323_8329", (56, TyVarFree NoLoc SizeLifted)), ("t_8326_8324_8320", (57, TyVarFree NoLoc Lifted)), ("t_8326_8324_8321", (57, TyVarFree NoLoc Lifted)), ("t_8326_8324_8322", (57, TyVarFree NoLoc Lifted)), ("t_8326_8324_8323", (57, TyVarFree NoLoc Lifted)), ("t_8326_8324_8324", (57, TyVarFree NoLoc Lifted)), ("if_t_8326_8324_8325", (48, TyVarFree NoLoc SizeLifted)), ("t_8326_8324_8326", (47, TyVarFree NoLoc Lifted)), ("t_8326_8325_8325", (18, TyVarFree NoLoc Lifted)), ("t_8326_8325_8325_8326_8325_8327", (18, TyVarFree NoLoc Lifted)), ("t_8322_8320_8326_8326_8325_8328", (16, TyVarFree NoLoc Lifted)), ("t_8326_8324_8326_8326_8325_8329", (47, TyVarFree NoLoc Lifted)), ("if_t_8326_8324_8325_8326_8326_8320", (48, TyVarFree NoLoc SizeLifted)), ("t_8324_8324_8326_8326_8326_8321", (47, TyVarFree NoLoc Lifted)), ("t_8326_8324_8320_8326_8326_8322", (57, TyVarFree NoLoc Lifted)), ("t_8324_8326_8325_8326_8326_8323", (56, TyVarFree NoLoc Lifted)), ("t_8324_8326_8326_8326_8326_8324", (56, TyVarFree NoLoc Lifted)), ("t_8326_8324_8321_8326_8326_8325", (57, TyVarFree NoLoc Lifted)), ("t_8324_8326_8323_8326_8326_8326", (53, TyVarFree NoLoc Lifted)), ("t_8327_8322_8326_8326_8327", (6, TyVarFree NoLoc Lifted)), ("t_8321_8325_8326_8326_8326_8328", (12, TyVarFree NoLoc Lifted)), ("b_8321_8328_8321_8326_8326_8329", (15, TyVarFree NoLoc Lifted)), ("a_8321_8328_8320_8326_8327_8320", (15, TyVarFree NoLoc Lifted)), ("t_8321_8327_8329_8326_8327_8321", (14, TyVarFree NoLoc Lifted)), ("t_8321_8327_8328_8326_8327_8322", (14, TyVarFree NoLoc Lifted)), ("a_8321_8325_8327_8326_8327_8323", (13, TyVarFree NoLoc Lifted)), ("t_8321_8321_8329_8326_8327_8324", (8, TyVarFree NoLoc Lifted)), ("b_8327_8324_8326_8327_8325", (7, TyVarFree NoLoc Lifted)), ("a_8327_8323_8326_8327_8326", (7, TyVarFree NoLoc Lifted)), ("b_8327_8326_8326_8327_8327", (7, TyVarFree NoLoc Lifted)), ("a_8327_8325_8326_8327_8328", (7, TyVarFree NoLoc Lifted))] + ), + ( [ "t_8320" ~ "t_8323", + "num_8324" ~ "t_8323", + "t_8323" ~ "f32", + "num_8329" ~ "f32", + "f32" ~ "f32", + "num_8321_8324" ~ "f32", + "t_8321_8329" ~ "f32", + "t_8321" ~ "t_8322_8320", + "num_8322_8321" ~ "t_8322_8320", + "t_8322_8320" ~ "f32", + "num_8322_8326" ~ "f32", + "f32" ~ "f32", + "num_8323_8321" ~ "f32", + "t_8323_8326" ~ "f32", + "t_8322" ~ "t_8323_8327", + "num_8323_8328" ~ "t_8323_8327", + "t_8323_8327" ~ "f32", + "num_8324_8323" ~ "f32", + "f32" ~ "f32", + "num_8324_8328" ~ "f32", + "t_8325_8323" ~ "f32", + "t_8321_8329" ~ "f32", + "u32" ~ "u32", + "t_8325_8328" ~ "u32", + "t_8323_8326" ~ "f32", + "u32" ~ "u32", + "t_8326_8323" ~ "u32", + "t_8325_8323" ~ "f32", + "u32" ~ "u32", + "t_8326_8328" ~ "u32", + "t_8325_8328" ~ "t_8327_8321", + "num_8327_8322" ~ "t_8327_8321", + "t_8326_8323" ~ "t_8327_8327", + "num_8327_8328" ~ "t_8327_8327", + "t_8327_8321" ~ "t_8327_8320", + "t_8327_8327" ~ "t_8327_8320", + "t_8327_8320" ~ "t_8326_8329", + "t_8326_8328" ~ "t_8326_8329", + "u32" ~ "t_8326_8329" + ], + M.empty, + M.fromList [("t_8320", (1, TyVarFree NoLoc Lifted)), ("t_8321", (1, TyVarFree NoLoc Lifted)), ("t_8322", (1, TyVarFree NoLoc Lifted)), ("t_8323", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8329", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8324", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329", (3, TyVarFree NoLoc Lifted)), ("t_8322_8320", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8321", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8326", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8321", (4, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8326", (5, TyVarFree NoLoc Lifted)), ("t_8323_8327", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8328", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324_8323", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8324_8328", (6, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8323", (7, TyVarFree NoLoc Lifted)), ("t_8325_8328", (9, TyVarFree NoLoc Lifted)), ("t_8326_8323", (11, TyVarFree NoLoc Lifted)), ("t_8326_8328", (13, TyVarFree NoLoc Lifted)), ("t_8326_8329", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8327_8320", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8327_8321", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8327_8322", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8327_8327", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8327_8328", (14, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64]))] + ), + ( [ "t_8323_8321_8325" ~ "[]t_8323_8321_8325_8323_8323_8326", + "t_8321_8325_8327" ~ "[]t_8321_8325_8327_8323_8323_8327", + "t_8322_8321_8328" ~ "[]t_8322_8321_8328_8323_8323_8328", + "t_8321_8327_8329" ~ "[]t_8321_8327_8329_8323_8323_8329", + "kt_8322_8326_8329" ~ "[]kt_8322_8326_8329_8323_8324_8320", + "kt_8322_8325_8324" ~ "[]kt_8322_8325_8324_8323_8324_8321", + "range_8322_8322_8327" ~ "[]range_8322_8322_8327_8323_8324_8322", + "t_8321_8327_8324" ~ "[]t_8321_8327_8324_8323_8324_8323", + "t_8321_8327_8321" ~ "[]t_8321_8327_8321_8323_8324_8324", + "ft_8324" ~ "ft_8326", + "a_8322" ~ "rt_8325", + "a_8322 -> b_8323" ~ "a_8320 -> x_8321", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]a_8320", + "[]x_8321" ~ "[]f32", + "t_8321_8323" ~ "f32", + "ft_8321_8328" ~ "ft_8322_8320", + "a_8321_8326" ~ "rt_8321_8329", + "a_8321_8326 -> b_8321_8327" ~ "a_8321_8324 -> x_8321_8325", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]a_8321_8324", + "[]x_8321_8325" ~ "[]f32", + "t_8322_8327" ~ "f32", + "ft_8323_8322" ~ "ft_8323_8324", + "a_8323_8320" ~ "rt_8323_8323", + "a_8323_8320 -> b_8323_8321" ~ "a_8322_8328 -> x_8322_8329", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]a_8322_8328", + "[]x_8322_8329" ~ "[]f32", + "t_8324_8321" ~ "f32", + "ft_8324_8326" ~ "ft_8324_8328", + "a_8324_8324" ~ "rt_8324_8327", + "a_8324_8324 -> b_8324_8325" ~ "a_8324_8322 -> x_8324_8323", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]a_8324_8322", + "[]x_8324_8323" ~ "[]f32", + "t_8325_8325" ~ "f32", + "ft_8326_8320" ~ "ft_8326_8322", + "a_8325_8328" ~ "rt_8326_8321", + "a_8325_8328 -> b_8325_8329" ~ "a_8325_8326 -> x_8325_8327", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]a_8325_8326", + "[]x_8325_8327" ~ "[]f32", + "t_8326_8329" ~ "f32", + "ft_8327_8324" ~ "ft_8327_8326", + "a_8327_8322" ~ "rt_8327_8325", + "a_8327_8322 -> b_8327_8323" ~ "a_8327_8320 -> x_8327_8321", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]a_8327_8320", + "[]x_8327_8321" ~ "[]f32", + "t_8328_8323" ~ "f32", + "t_8328_8324" ~ "t_8328_8328", + "t_8325_8325" ~ "t_8328_8328", + "t_8321_8323" ~ "t_8329_8323", + "t_8325_8325" ~ "t_8329_8323", + "t_8328_8328" ~ "t_8328_8327", + "t_8329_8323" ~ "t_8328_8327", + "t_8328_8325" ~ "t_8321_8320_8323", + "t_8326_8329" ~ "t_8321_8320_8323", + "t_8322_8327" ~ "t_8321_8320_8328", + "t_8326_8329" ~ "t_8321_8320_8328", + "t_8321_8320_8323" ~ "t_8321_8320_8322", + "t_8321_8320_8328" ~ "t_8321_8320_8322", + "t_8328_8326" ~ "t_8321_8321_8328", + "t_8328_8323" ~ "t_8321_8321_8328", + "t_8324_8321" ~ "t_8321_8322_8323", + "t_8328_8323" ~ "t_8321_8322_8323", + "t_8321_8321_8328" ~ "t_8321_8321_8327", + "t_8321_8322_8323" ~ "t_8321_8321_8327", + "{x: t_8328_8324, y: t_8328_8325, z: t_8328_8326} -> {x: t_8328_8327, y: t_8321_8320_8322, z: t_8321_8321_8327}" ~ "a_8321_8323_8322 -> b_8321_8323_8323", + "{x: f32, y: f32, z: f32} -> u32" ~ "b_8321_8323_8323 -> c_8321_8323_8324", + "t_8321_8323_8329" ~ "{x: a_8321_8323_8322} -> c_8321_8323_8324", + "t_8321_8324_8322" ~ "t_8321_8324_8324", + "t_8321_8323_8329" ~ "arg_8321_8324_8325 -> res_8321_8324_8326", + "kt_8321_8324_8323" ~ "arg_8321_8324_8325", + "{p: t_8321_8324_8322} -> res_8321_8324_8326" ~ "t_8321_8324_8320 -> k_8321_8324_8321", + "i32" ~ "i32", + "i32 -> u32 -> i32" ~ "i32 -> k_8321_8324_8321 -> i32", + "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "[]t_8321_8324_8320", + "t_8321_8325_8327" ~ "[]t_8321_8324_8320", + "t_8321_8326_8320" ~ "t_8321_8326_8322", + "t_8321_8323_8329" ~ "arg_8321_8326_8323 -> res_8321_8326_8324", + "kt_8321_8326_8321" ~ "arg_8321_8326_8323", + "{p: t_8321_8326_8320} -> res_8321_8326_8324" ~ "a_8321_8325_8328 -> x_8321_8325_8329", + "t_8321_8325_8327" ~ "[]a_8321_8325_8328", + "t_8321_8327_8321" ~ "[]x_8321_8325_8329", + "t_8321_8327_8321" ~ "[]u32", + "t_8321_8327_8324" ~ "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}", + "t_8321_8325_8327" ~ "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}", + "t_8321_8327_8324" ~ "[]{delta_node: i32, left: #inner i32 | #leaf i32, parent: i32, right: #inner i32 | #leaf i32, sfc_code: u32}", + "t_8321_8327_8329" ~ "[]{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}", + "t_8321_8328_8321" ~ "num_8321_8328_8320", + "index_8321_8328_8322" ~ "index_elem_8321_8328_8323", + "t_8321_8327_8324" ~ "[]index_elem_8321_8328_8323", + "index_8321_8328_8322" ~ "t_8321_8328_8325", + "t_8321_8328_8326" ~ "kt_8321_8328_8324", + "i32" ~ "t_8321_8328_8328", + "num_8321_8328_8329" ~ "t_8321_8328_8328", + "t_8321_8328_8328" ~ "t_8321_8328_8327", + "t_8321_8328_8326" ~ "t_8321_8328_8327", + "t_8321_8329_8328" ~ "t_8321_8328_8327", + "t_8321_8329_8328" ~ "t_8322_8320_8321", + "num_8322_8320_8322" ~ "t_8322_8320_8321", + "t_8322_8320_8321" ~ "t_8322_8320_8320", + "num_8322_8320_8327" ~ "t_8322_8320_8320", + "t_8322_8320_8320" ~ "t_8321_8329_8329", + "num_8322_8321_8322" ~ "t_8321_8329_8329", + "t_8322_8321_8327" ~ "t_8321_8329_8329", + "t_8322_8321_8328" ~ "t_8321_8327_8329", + "t_8322_8321_8329" ~ "t_8322_8321_8327", + "t_8322_8321_8327" ~ "t_8322_8322_8320", + "num_8322_8322_8321" ~ "t_8322_8322_8320", + "t_8322_8321_8327" ~ "t_8322_8322_8320", + "t_8322_8321_8327" ~ "num_8322_8322_8326", + "range_8322_8322_8327" ~ "[]t_8322_8321_8327", + "range_8322_8322_8327" ~ "[]elem_8322_8322_8328", + "t_8322_8322_8329" ~ "elem_8322_8322_8328", + "{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "t_8322_8323_8324", + "kt_8322_8323_8323" ~ "t_8322_8323_8322", + "t_8322_8322_8329" ~ "t_8322_8323_8322", + "t_8322_8324_8320" ~ "num_8322_8323_8329", + "{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "t_8322_8324_8322", + "t_8322_8324_8323" ~ "kt_8322_8324_8321", + "(t_8322_8324_8324, t_8322_8324_8325)" ~ "(t_8322_8324_8323, t_8322_8324_8320)", + "t_8322_8324_8325" ~ "t_8322_8324_8327", + "num_8322_8324_8328" ~ "t_8322_8324_8327", + "{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "t_8322_8325_8325", + "t_8322_8325_8326" ~ "t_8322_8324_8325", + "index_8322_8325_8327" ~ "index_elem_8322_8325_8328", + "kt_8322_8325_8324" ~ "[]index_elem_8322_8325_8328", + "t_8322_8326_8320" ~ "num_8322_8325_8329", + "index_8322_8325_8327" ~ "t_8322_8325_8323", + "num_8322_8325_8329" ~ "t_8322_8325_8323", + "bool" ~ "t_8322_8324_8326", + "bool" ~ "t_8322_8324_8326", + "{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "t_8322_8327_8320", + "t_8322_8327_8321" ~ "t_8322_8324_8325", + "index_8322_8327_8322" ~ "index_elem_8322_8327_8323", + "kt_8322_8326_8329" ~ "[]index_elem_8322_8327_8323", + "t_8322_8327_8324" ~ "index_8322_8327_8322", + "index_8322_8327_8325" ~ "index_elem_8322_8327_8326", + "t_8322_8321_8328" ~ "[]index_elem_8322_8327_8326", + "index_8322_8327_8325" ~ "t_8322_8327_8328", + "t_8322_8327_8329" ~ "kt_8322_8327_8327", + "t_8322_8324_8324" ~ "t_8322_8328_8321", + "t_8322_8327_8329" ~ "t_8322_8328_8323", + "kt_8322_8328_8320" ~ "{x: f32, y: f32, z: f32}", + "kt_8322_8328_8322" ~ "{x: f32, y: f32, z: f32}", + "t_8322_8328_8328" ~ "{x: f32, y: f32, z: f32}", + "t_8322_8324_8324" ~ "t_8322_8329_8321", + "t_8322_8327_8329" ~ "t_8322_8329_8323", + "kt_8322_8329_8320" ~ "t_8322_8328_8329", + "kt_8322_8329_8322" ~ "t_8322_8328_8329", + "t_8322_8329_8328" ~ "t_8322_8328_8329", + "t_8322_8324_8324" ~ "t_8323_8320_8320", + "t_8322_8324_8325" ~ "t_8323_8320_8321", + "num_8323_8320_8322" ~ "t_8323_8320_8321", + "(t_8322_8324_8323, t_8322_8324_8320)" ~ "({mass: t_8322_8329_8328, position: t_8322_8328_8328, velocity: kt_8322_8329_8329}, t_8323_8320_8321)", + "(t_8323_8320_8327, t_8323_8320_8328)" ~ "(t_8322_8324_8324, t_8322_8324_8325)", + "{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "ft_8323_8320_8329", + "bool" ~ "bool", + "{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "if_t_8323_8321_8320", + "{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "if_t_8323_8321_8320", + "{n: {body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}} -> if_t_8323_8321_8320" ~ "a_8322_8323_8320 -> x_8322_8323_8321", + "t_8322_8321_8328" ~ "[]a_8322_8323_8320", + "t_8321_8327_8329" ~ "[]x_8322_8323_8321", + "t_8323_8321_8325" ~ "t_8322_8321_8328", + "t_8321_8323" ~ "t_8323_8321_8327", + "t_8325_8325" ~ "t_8323_8321_8327", + "t_8323_8321_8327" ~ "et_8323_8321_8326", + "t_8322_8327" ~ "t_8323_8322_8322", + "t_8326_8329" ~ "t_8323_8322_8322", + "t_8323_8322_8322" ~ "et_8323_8321_8326", + "t_8324_8321" ~ "t_8323_8322_8327", + "t_8328_8323" ~ "t_8323_8322_8327", + "t_8323_8322_8327" ~ "et_8323_8321_8326", + "[]et_8323_8321_8326" ~ "[]f32", + "t_8323_8323_8324" ~ "f32", + "([]{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}, f32, i32, []{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32})}" ~ "(t_8323_8321_8325, t_8323_8323_8324, t_8321_8328_8326, t_8321_8325_8327)" + ], + M.empty, + M.fromList [("a_8320", (2, TyVarFree NoLoc Unlifted)), ("x_8321", (2, TyVarFree NoLoc Unlifted)), ("a_8322", (2, TyVarFree NoLoc Lifted)), ("b_8323", (2, TyVarFree NoLoc Lifted)), ("ft_8324", (2, TyVarFree NoLoc Lifted)), ("rt_8325", (2, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "ft_8324" 16341}) []))]))), ("ft_8326", (2, TyVarRecord NoLoc (M.fromList [("x", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8323" 16340}) []))]))), ("t_8321_8323", (3, TyVarFree NoLoc Lifted)), ("a_8321_8324", (4, TyVarFree NoLoc Unlifted)), ("x_8321_8325", (4, TyVarFree NoLoc Unlifted)), ("a_8321_8326", (4, TyVarFree NoLoc Lifted)), ("b_8321_8327", (4, TyVarFree NoLoc Lifted)), ("ft_8321_8328", (4, TyVarFree NoLoc Lifted)), ("rt_8321_8329", (4, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "ft_8321_8328" 16361}) []))]))), ("ft_8322_8320", (4, TyVarRecord NoLoc (M.fromList [("y", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8321_8327" 16360}) []))]))), ("t_8322_8327", (5, TyVarFree NoLoc Lifted)), ("a_8322_8328", (6, TyVarFree NoLoc Unlifted)), ("x_8322_8329", (6, TyVarFree NoLoc Unlifted)), ("a_8323_8320", (6, TyVarFree NoLoc Lifted)), ("b_8323_8321", (6, TyVarFree NoLoc Lifted)), ("ft_8323_8322", (6, TyVarFree NoLoc Lifted)), ("rt_8323_8323", (6, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "ft_8323_8322" 16381}) []))]))), ("ft_8323_8324", (6, TyVarRecord NoLoc (M.fromList [("z", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8323_8321" 16380}) []))]))), ("t_8324_8321", (7, TyVarFree NoLoc Lifted)), ("a_8324_8322", (8, TyVarFree NoLoc Unlifted)), ("x_8324_8323", (8, TyVarFree NoLoc Unlifted)), ("a_8324_8324", (8, TyVarFree NoLoc Lifted)), ("b_8324_8325", (8, TyVarFree NoLoc Lifted)), ("ft_8324_8326", (8, TyVarFree NoLoc Lifted)), ("rt_8324_8327", (8, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "ft_8324_8326" 16401}) []))]))), ("ft_8324_8328", (8, TyVarRecord NoLoc (M.fromList [("x", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8324_8325" 16400}) []))]))), ("t_8325_8325", (9, TyVarFree NoLoc Lifted)), ("a_8325_8326", (10, TyVarFree NoLoc Unlifted)), ("x_8325_8327", (10, TyVarFree NoLoc Unlifted)), ("a_8325_8328", (10, TyVarFree NoLoc Lifted)), ("b_8325_8329", (10, TyVarFree NoLoc Lifted)), ("ft_8326_8320", (10, TyVarFree NoLoc Lifted)), ("rt_8326_8321", (10, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "ft_8326_8320" 16421}) []))]))), ("ft_8326_8322", (10, TyVarRecord NoLoc (M.fromList [("y", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8325_8329" 16420}) []))]))), ("t_8326_8329", (11, TyVarFree NoLoc Lifted)), ("a_8327_8320", (12, TyVarFree NoLoc Unlifted)), ("x_8327_8321", (12, TyVarFree NoLoc Unlifted)), ("a_8327_8322", (12, TyVarFree NoLoc Lifted)), ("b_8327_8323", (12, TyVarFree NoLoc Lifted)), ("ft_8327_8324", (12, TyVarFree NoLoc Lifted)), ("rt_8327_8325", (12, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "ft_8327_8324" 16441}) []))]))), ("ft_8327_8326", (12, TyVarRecord NoLoc (M.fromList [("z", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "b_8327_8323" 16440}) []))]))), ("t_8328_8323", (13, TyVarFree NoLoc Lifted)), ("t_8328_8324", (15, TyVarFree NoLoc Lifted)), ("t_8328_8325", (15, TyVarFree NoLoc Lifted)), ("t_8328_8326", (15, TyVarFree NoLoc Lifted)), ("t_8328_8327", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8328_8328", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8323", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8322", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8323", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8328", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8321_8327", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8321_8328", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8322_8323", (16, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("a_8321_8323_8322", (14, TyVarFree NoLoc Lifted)), ("b_8321_8323_8323", (14, TyVarFree NoLoc Lifted)), ("c_8321_8323_8324", (14, TyVarFree NoLoc Lifted)), ("t_8321_8323_8329", (15, TyVarFree NoLoc Lifted)), ("t_8321_8324_8320", (16, TyVarFree NoLoc Unlifted)), ("k_8321_8324_8321", (16, TyVarFree NoLoc Unlifted)), ("t_8321_8324_8322", (17, TyVarFree NoLoc Lifted)), ("kt_8321_8324_8323", (18, TyVarFree NoLoc Lifted)), ("t_8321_8324_8324", (18, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8324_8323" 16528}) []))]))), ("arg_8321_8324_8325", (18, TyVarFree NoLoc Lifted)), ("res_8321_8324_8326", (18, TyVarFree NoLoc Lifted)), ("t_8321_8325_8327", (17, TyVarFree NoLoc Lifted)), ("a_8321_8325_8328", (18, TyVarFree NoLoc Unlifted)), ("x_8321_8325_8329", (18, TyVarFree NoLoc Unlifted)), ("t_8321_8326_8320", (19, TyVarFree NoLoc Lifted)), ("kt_8321_8326_8321", (20, TyVarFree NoLoc Lifted)), ("t_8321_8326_8322", (20, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8326_8321" 16552}) []))]))), ("arg_8321_8326_8323", (20, TyVarFree NoLoc Lifted)), ("res_8321_8326_8324", (20, TyVarFree NoLoc Lifted)), ("t_8321_8327_8321", (19, TyVarFree NoLoc Lifted)), ("t_8321_8327_8324", (21, TyVarFree NoLoc Lifted)), ("t_8321_8327_8329", (23, TyVarFree NoLoc Lifted)), ("num_8321_8328_8320", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8328_8321", (24, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8321_8328_8322", (24, TyVarFree NoLoc Unlifted)), ("index_elem_8321_8328_8323", (24, TyVarFree NoLoc Unlifted)), ("kt_8321_8328_8324", (24, TyVarFree NoLoc Lifted)), ("t_8321_8328_8325", (24, TyVarRecord NoLoc (M.fromList [("delta_node", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8328_8324" 16585}) []))]))), ("t_8321_8328_8326", (25, TyVarFree NoLoc Lifted)), ("t_8321_8328_8327", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8328_8328", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321_8328_8329", (26, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8328", (27, TyVarFree NoLoc Lifted)), ("t_8321_8329_8329", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8320", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8321", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8320_8322", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8320_8327", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8321_8322", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8321_8327", (29, TyVarFree NoLoc Lifted)), ("t_8322_8321_8328", (30, TyVarFree NoLoc Lifted)), ("t_8322_8321_8329", (30, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64])), ("t_8322_8322_8320", (30, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8322_8321", (30, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8322_8326", (30, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("range_8322_8322_8327", (30, TyVarFree NoLoc Unlifted)), ("elem_8322_8322_8328", (30, TyVarFree NoLoc Unlifted)), ("t_8322_8322_8329", (30, TyVarFree NoLoc Lifted)), ("a_8322_8323_8320", (30, TyVarFree NoLoc Unlifted)), ("x_8322_8323_8321", (30, TyVarFree NoLoc Unlifted)), ("t_8322_8323_8322", (32, TyVarFree NoLoc Unlifted)), ("kt_8322_8323_8323", (32, TyVarFree NoLoc Lifted)), ("t_8322_8323_8324", (32, TyVarRecord NoLoc (M.fromList [("tree_level", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8323_8323" 16643}) []))]))), ("num_8322_8323_8329", (32, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8320", (33, TyVarFree NoLoc Lifted)), ("kt_8322_8324_8321", (34, TyVarFree NoLoc Lifted)), ("t_8322_8324_8322", (34, TyVarRecord NoLoc (M.fromList [("body", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8324_8321" 16654}) []))]))), ("t_8322_8324_8323", (35, TyVarFree NoLoc Lifted)), ("t_8322_8324_8324", (36, TyVarFree NoLoc Lifted)), ("t_8322_8324_8325", (36, TyVarFree NoLoc Lifted)), ("t_8322_8324_8326", (36, TyVarPrim NoLoc [Bool])), ("t_8322_8324_8327", (36, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8324_8328", (36, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8325_8323", (36, TyVarFree NoLoc Unlifted)), ("kt_8322_8325_8324", (36, TyVarFree NoLoc Lifted)), ("t_8322_8325_8325", (36, TyVarRecord NoLoc (M.fromList [("children", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8325_8324" 16669}) []))]))), ("t_8322_8325_8326", (36, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8325_8327", (36, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8325_8328", (36, TyVarFree NoLoc Unlifted)), ("num_8322_8325_8329", (36, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8326_8320", (36, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8322_8326_8329", (36, TyVarFree NoLoc Lifted)), ("t_8322_8327_8320", (36, TyVarRecord NoLoc (M.fromList [("children", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8326_8329" 16687}) []))]))), ("t_8322_8327_8321", (36, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8327_8322", (36, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8327_8323", (36, TyVarFree NoLoc Unlifted)), ("t_8322_8327_8324", (36, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8327_8325", (36, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8327_8326", (36, TyVarFree NoLoc Unlifted)), ("kt_8322_8327_8327", (36, TyVarFree NoLoc Lifted)), ("t_8322_8327_8328", (36, TyVarRecord NoLoc (M.fromList [("body", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8327_8327" 16695}) []))]))), ("t_8322_8327_8329", (37, TyVarFree NoLoc Lifted)), ("kt_8322_8328_8320", (38, TyVarFree NoLoc Lifted)), ("t_8322_8328_8321", (38, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8328_8320" 16698}) []))]))), ("kt_8322_8328_8322", (38, TyVarFree NoLoc Lifted)), ("t_8322_8328_8323", (38, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8328_8322" 16700}) []))]))), ("t_8322_8328_8328", (39, TyVarFree NoLoc Lifted)), ("t_8322_8328_8329", (40, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8322_8329_8320", (40, TyVarFree NoLoc Lifted)), ("t_8322_8329_8321", (40, TyVarRecord NoLoc (M.fromList [("mass", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8329_8320" 16709}) []))]))), ("kt_8322_8329_8322", (40, TyVarFree NoLoc Lifted)), ("t_8322_8329_8323", (40, TyVarRecord NoLoc (M.fromList [("mass", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8329_8322" 16711}) []))]))), ("t_8322_8329_8328", (41, TyVarFree NoLoc Lifted)), ("kt_8322_8329_8329", (42, TyVarFree NoLoc Lifted)), ("t_8323_8320_8320", (42, TyVarRecord NoLoc (M.fromList [("velocity", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8329_8329" 16719}) []))]))), ("t_8323_8320_8321", (42, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8320_8322", (42, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8320_8327", (37, TyVarFree NoLoc Lifted)), ("t_8323_8320_8328", (37, TyVarFree NoLoc Lifted)), ("ft_8323_8320_8329", (38, TyVarRecord NoLoc (M.fromList [("body", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "t_8323_8320_8327" 16731}) []))]))), ("if_t_8323_8321_8320", (32, TyVarFree NoLoc SizeLifted)), ("t_8323_8321_8325", (27, TyVarFree NoLoc Lifted)), ("et_8323_8321_8326", (28, TyVarFree NoLoc Unlifted)), ("t_8323_8321_8327", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8322_8322", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8322_8327", (28, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8323_8324", (29, TyVarFree NoLoc Lifted)), ("t_8323_8321_8325_8323_8323_8326", (27, TyVarFree NoLoc Lifted)), ("t_8321_8325_8327_8323_8323_8327", (17, TyVarFree NoLoc Lifted)), ("t_8322_8321_8328_8323_8323_8328", (30, TyVarFree NoLoc Lifted)), ("t_8321_8327_8329_8323_8323_8329", (23, TyVarFree NoLoc Lifted)), ("kt_8322_8326_8329_8323_8324_8320", (36, TyVarFree NoLoc Lifted)), ("kt_8322_8325_8324_8323_8324_8321", (36, TyVarFree NoLoc Lifted)), ("range_8322_8322_8327_8323_8324_8322", (30, TyVarFree NoLoc Unlifted)), ("t_8321_8327_8324_8323_8324_8323", (21, TyVarFree NoLoc Lifted)), ("t_8321_8327_8321_8323_8324_8324", (19, TyVarFree NoLoc Lifted))] + ), + ( [ "t_8322_8326_8324" ~ "[]t_8322_8326_8324_8323_8322_8320", + "a_8322_8326_8325" ~ "[]a_8322_8326_8325_8323_8322_8321", + "kt_8322_8326_8322" ~ "[]kt_8322_8326_8322_8323_8322_8322", + "kt_8329_8321" ~ "[]kt_8329_8321_8323_8322_8323", + "kt_8326_8328" ~ "[]kt_8326_8328_8323_8322_8324", + "kt_8324_8324" ~ "[]kt_8324_8324_8323_8322_8325", + "t_8322" ~ "num_8321", + "(t_8324, t_8325, t_8326, t_8327)" ~ "({x: f32, y: f32, z: f32}, num_8320, num_8321, num_8323)", + "t_8327" ~ "t_8328", + "num_8329" ~ "t_8328", + "t_8321_8327" ~ "num_8321_8326", + "t_8326" ~ "t_8321_8325", + "num_8321_8326" ~ "t_8321_8325", + "t_8322_8323" ~ "t_8326", + "index_8322_8324" ~ "index_elem_8322_8325", + "[]{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "[]index_elem_8322_8325", + "index_8322_8324" ~ "t_8322_8327", + "kt_8322_8326" ~ "t_8322_8322", + "t_8325" ~ "t_8322_8322", + "bool" ~ "t_8321_8324", + "bool" ~ "t_8321_8324", + "t_8323_8326" ~ "t_8321_8324", + "t_8323_8327" ~ "t_8325", + "index_8323_8328" ~ "index_elem_8323_8329", + "[]{body: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}, children: []i32, is_leaf: bool, parent: i32, tree_level: i32}" ~ "[]index_elem_8323_8329", + "t_8324_8320" ~ "index_8323_8328", + "t_8324_8322" ~ "num_8324_8321", + "t_8324_8320" ~ "t_8324_8325", + "t_8324_8326" ~ "t_8324_8322", + "index_8324_8327" ~ "index_elem_8324_8328", + "kt_8324_8324" ~ "[]index_elem_8324_8328", + "index_8324_8327" ~ "t_8324_8323", + "t_8326" ~ "t_8324_8323", + "t_8324_8322" ~ "t_8325_8323", + "num_8325_8324" ~ "t_8325_8323", + "num_8324_8321" ~ "t_8325_8323", + "t_8325_8329" ~ "t_8324_8322", + "t_8325_8329" ~ "t_8326_8321", + "num_8326_8322" ~ "t_8326_8321", + "t_8324_8320" ~ "t_8326_8329", + "t_8325_8329" ~ "t_8327_8320", + "num_8327_8321" ~ "t_8327_8320", + "t_8327_8326" ~ "t_8327_8320", + "index_8327_8327" ~ "index_elem_8327_8328", + "kt_8326_8328" ~ "[]index_elem_8327_8328", + "t_8328_8320" ~ "num_8327_8329", + "index_8327_8327" ~ "t_8326_8327", + "num_8327_8329" ~ "t_8326_8327", + "bool" ~ "t_8326_8320", + "bool" ~ "t_8326_8320", + "t_8324_8320" ~ "t_8329_8320", + "t_8324_8320" ~ "t_8329_8322", + "t_8325_8329" ~ "t_8329_8323", + "num_8329_8324" ~ "t_8329_8323", + "t_8329_8329" ~ "t_8329_8323", + "index_8321_8320_8320" ~ "index_elem_8321_8320_8321", + "kt_8329_8321" ~ "[]index_elem_8321_8320_8321", + "t_8326_8320" ~ "bool", + "kt_8328_8329" ~ "if_t_8321_8320_8322", + "index_8321_8320_8320" ~ "if_t_8321_8320_8322", + "t_8321_8320_8323" ~ "if_t_8321_8320_8322", + "t_8324_8320" ~ "t_8321_8320_8327", + "i32" ~ "t_8321_8320_8325", + "kt_8321_8320_8326" ~ "t_8321_8320_8325", + "t_8321_8320_8325" ~ "i32", + "f32" ~ "t_8321_8320_8324", + "f32" ~ "t_8321_8320_8324", + "t_8321_8321_8328" ~ "t_8321_8320_8324", + "t_8324_8320" ~ "t_8321_8322_8320", + "kt_8321_8321_8329" ~ "t_8321_8322_8322", + "t_8321_8322_8323" ~ "kt_8321_8322_8321", + "t_8324_8320" ~ "t_8321_8322_8325", + "kt_8321_8322_8324" ~ "t_8321_8322_8327", + "t_8321_8322_8328" ~ "kt_8321_8322_8326", + "f32" ~ "t_8321_8322_8329", + "t_8321_8322_8328" ~ "t_8321_8322_8329", + "t_8321_8322_8329" ~ "f32", + "t_8321_8322_8323" ~ "{x: f32, y: f32, z: f32}", + "t_8321_8323_8328" ~ "{x: f32, y: f32, z: f32}", + "{mass: f32, position: {x: f32, y: f32, z: f32}}" ~ "t_8321_8324_8320", + "t_8321_8324_8321" ~ "kt_8321_8323_8329", + "t_8321_8324_8321" ~ "t_8321_8324_8324", + "t_8321_8323_8328" ~ "t_8321_8324_8326", + "kt_8321_8324_8323" ~ "t_8321_8324_8322", + "kt_8321_8324_8325" ~ "t_8321_8324_8322", + "t_8321_8325_8321" ~ "t_8321_8324_8322", + "t_8321_8324_8321" ~ "t_8321_8325_8324", + "t_8321_8323_8328" ~ "t_8321_8325_8326", + "kt_8321_8325_8323" ~ "t_8321_8325_8322", + "kt_8321_8325_8325" ~ "t_8321_8325_8322", + "t_8321_8326_8321" ~ "t_8321_8325_8322", + "t_8321_8324_8321" ~ "t_8321_8326_8324", + "t_8321_8323_8328" ~ "t_8321_8326_8326", + "kt_8321_8326_8323" ~ "t_8321_8326_8322", + "kt_8321_8326_8325" ~ "t_8321_8326_8322", + "t_8321_8327_8321" ~ "t_8321_8326_8322", + "t_8321_8325_8321" ~ "t_8321_8327_8324", + "t_8321_8325_8321" ~ "t_8321_8327_8324", + "t_8321_8326_8321" ~ "t_8321_8327_8329", + "t_8321_8326_8321" ~ "t_8321_8327_8329", + "t_8321_8327_8324" ~ "t_8321_8327_8323", + "t_8321_8327_8329" ~ "t_8321_8327_8323", + "t_8321_8327_8321" ~ "t_8321_8328_8328", + "t_8321_8327_8321" ~ "t_8321_8328_8328", + "t_8321_8327_8323" ~ "t_8321_8327_8322", + "t_8321_8328_8328" ~ "t_8321_8327_8322", + "t_8321_8329_8327" ~ "t_8321_8327_8322", + "t_8321_8329_8327" ~ "f32", + "t_8321_8321_8328" ~ "t_8321_8329_8328", + "f32" ~ "t_8321_8329_8328", + "t_8322_8320_8325" ~ "t_8321_8329_8328", + "i32" ~ "t_8322_8320_8328", + "num_8322_8320_8329" ~ "t_8322_8320_8328", + "t_8324_8320" ~ "t_8322_8321_8325", + "t_8322_8320_8328" ~ "t_8322_8320_8327", + "kt_8322_8321_8324" ~ "t_8322_8320_8327", + "t_8322_8320_8327" ~ "t_8322_8320_8326", + "num_8322_8322_8320" ~ "t_8322_8320_8326", + "t_8322_8322_8325" ~ "bool", + "t_8324_8320" ~ "t_8322_8322_8329", + "kt_8322_8322_8328" ~ "t_8322_8322_8327", + "t_8322_8322_8325" ~ "t_8322_8322_8327", + "t_8322_8320_8325" ~ "t_8322_8323_8324", + "f32" ~ "t_8322_8323_8324", + "t_8322_8322_8327" ~ "t_8322_8322_8326", + "bool" ~ "t_8322_8322_8326", + "f32" ~ "f32", + "{mass: f32, position: {x: f32, y: f32, z: f32}}" ~ "{mass: f32, position: {x: f32, y: f32, z: f32}}", + "{mass: t_8321_8322_8328, position: t_8321_8323_8328}" ~ "{mass: f32, position: {x: f32, y: f32, z: f32}}", + "t_8322_8324_8329" ~ "{x: f32, y: f32, z: f32}", + "t_8324" ~ "{x: f32, y: f32, z: f32}", + "t_8322_8324_8329" ~ "{x: f32, y: f32, z: f32}", + "t_8324_8320" ~ "t_8322_8325_8325", + "t_8327" ~ "t_8322_8325_8326", + "num_8322_8325_8327" ~ "t_8322_8325_8326", + "t_8324_8320" ~ "t_8322_8326_8323", + "t_8322_8326_8324" ~ "kt_8322_8326_8322", + "t_8322_8327_8324" ~ "t_8322_8327_8322", + "num_8322_8327_8323" ~ "t_8322_8327_8322", + "bool -> i32" ~ "b_8322_8327_8320 -> c_8322_8327_8321", + "t_8322_8327_8324 -> bool" ~ "a_8322_8326_8329 -> b_8322_8327_8320", + "{x: a_8322_8326_8329} -> c_8322_8327_8321" ~ "a_8322_8326_8327 -> x_8322_8326_8328", + "t_8322_8326_8324" ~ "[]a_8322_8326_8327", + "t_8322_8328_8328 -> t_8322_8328_8328 -> t_8322_8328_8328" ~ "a_8322_8328_8327 -> a_8322_8328_8327 -> a_8322_8328_8327", + "num_8322_8328_8329" ~ "a_8322_8328_8327", + "[]x_8322_8326_8328" ~ "a_8322_8326_8325", + "{as: []a_8322_8328_8327} -> a_8322_8328_8327" ~ "a_8322_8326_8325 -> b_8322_8326_8326", + "t_8322_8329_8328" ~ "b_8322_8326_8326", + "t_8323_8320_8320" ~ "num_8322_8329_8329", + "index_8323_8320_8321" ~ "index_elem_8323_8320_8322", + "t_8322_8326_8324" ~ "[]index_elem_8323_8320_8322", + "t_8327" ~ "t_8323_8320_8324", + "t_8322_8329_8328" ~ "t_8323_8320_8324", + "t_8323_8320_8324" ~ "t_8323_8320_8323", + "num_8323_8320_8329" ~ "t_8323_8320_8323", + "t_8322_8322_8326" ~ "bool", + "({x: f32, y: f32, z: f32}, kt_8322_8325_8324, t_8325, t_8322_8325_8326)" ~ "if_t_8323_8321_8324", + "(t_8324, index_8323_8320_8321, t_8325, t_8323_8320_8323)" ~ "if_t_8323_8321_8324", + "t_8323_8326" ~ "bool", + "(t_8324, t_8321_8320_8323, t_8325, t_8327)" ~ "if_t_8323_8321_8325", + "if_t_8323_8321_8324" ~ "if_t_8323_8321_8325", + "({x: f32, y: f32, z: f32}, num_8320, num_8321, num_8323)" ~ "if_t_8323_8321_8325", + "(t_8323_8321_8326, t_8323_8321_8327, t_8323_8321_8328, t_8323_8321_8329)" ~ "(t_8324, t_8325, t_8326, t_8327)", + "{x: f32, y: f32, z: f32}" ~ "t_8323_8321_8326" + ], + M.empty, + M.fromList [("num_8320", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8321", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324", (5, TyVarFree NoLoc Lifted)), ("t_8325", (5, TyVarFree NoLoc Lifted)), ("t_8326", (5, TyVarFree NoLoc Lifted)), ("t_8327", (5, TyVarFree NoLoc Lifted)), ("t_8328", (5, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8329", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8324", (5, TyVarPrim NoLoc [Bool])), ("t_8321_8325", (5, TyVarFree NoLoc Unlifted)), ("num_8321_8326", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8322", (5, TyVarFree NoLoc Unlifted)), ("t_8322_8323", (5, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8322_8324", (5, TyVarFree NoLoc Unlifted)), ("index_elem_8322_8325", (5, TyVarFree NoLoc Unlifted)), ("kt_8322_8326", (5, TyVarFree NoLoc Lifted)), ("t_8322_8327", (5, TyVarRecord NoLoc (M.fromList [("parent", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8326" 17047}) []))]))), ("t_8323_8326", (6, TyVarFree NoLoc Lifted)), ("t_8323_8327", (7, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8328", (7, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8329", (7, TyVarFree NoLoc Unlifted)), ("t_8324_8320", (8, TyVarFree NoLoc Lifted)), ("num_8324_8321", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8324_8322", (9, TyVarFree NoLoc Lifted)), ("t_8324_8323", (9, TyVarFree NoLoc Unlifted)), ("kt_8324_8324", (9, TyVarFree NoLoc Lifted)), ("t_8324_8325", (9, TyVarRecord NoLoc (M.fromList [("children", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8324_8324" 17068}) []))]))), ("t_8324_8326", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8324_8327", (9, TyVarFree NoLoc Unlifted)), ("index_elem_8324_8328", (9, TyVarFree NoLoc Unlifted)), ("t_8325_8323", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8325_8324", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8325_8329", (10, TyVarFree NoLoc Lifted)), ("t_8326_8320", (11, TyVarPrim NoLoc [Bool])), ("t_8326_8321", (11, TyVarFree NoLoc Unlifted)), ("num_8326_8322", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8326_8327", (11, TyVarFree NoLoc Unlifted)), ("kt_8326_8328", (11, TyVarFree NoLoc Lifted)), ("t_8326_8329", (11, TyVarRecord NoLoc (M.fromList [("children", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8326_8328" 17095}) []))]))), ("t_8327_8320", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8327_8321", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8327_8326", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8327_8327", (11, TyVarFree NoLoc Unlifted)), ("index_elem_8327_8328", (11, TyVarFree NoLoc Unlifted)), ("num_8327_8329", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8328_8320", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8328_8329", (11, TyVarFree NoLoc Lifted)), ("t_8329_8320", (11, TyVarRecord NoLoc (M.fromList [("parent", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8328_8329" 17119}) []))]))), ("kt_8329_8321", (11, TyVarFree NoLoc Lifted)), ("t_8329_8322", (11, TyVarRecord NoLoc (M.fromList [("children", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8329_8321" 17121}) []))]))), ("t_8329_8323", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8329_8324", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8329_8329", (11, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8321_8320_8320", (11, TyVarFree NoLoc Unlifted)), ("index_elem_8321_8320_8321", (11, TyVarFree NoLoc Unlifted)), ("if_t_8321_8320_8322", (11, TyVarFree NoLoc SizeLifted)), ("t_8321_8320_8323", (12, TyVarFree NoLoc Lifted)), ("t_8321_8320_8324", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8320_8325", (9, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8321_8320_8326", (9, TyVarFree NoLoc Lifted)), ("t_8321_8320_8327", (9, TyVarRecord NoLoc (M.fromList [("tree_level", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8320_8326" 17137}) []))]))), ("t_8321_8321_8328", (10, TyVarFree NoLoc Lifted)), ("kt_8321_8321_8329", (11, TyVarFree NoLoc Lifted)), ("t_8321_8322_8320", (11, TyVarRecord NoLoc (M.fromList [("body", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8321_8329" 17153}) []))]))), ("kt_8321_8322_8321", (11, TyVarFree NoLoc Lifted)), ("t_8321_8322_8322", (11, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8322_8321" 17155}) []))]))), ("t_8321_8322_8323", (12, TyVarFree NoLoc Lifted)), ("kt_8321_8322_8324", (13, TyVarFree NoLoc Lifted)), ("t_8321_8322_8325", (13, TyVarRecord NoLoc (M.fromList [("body", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8322_8324" 17158}) []))]))), ("kt_8321_8322_8326", (13, TyVarFree NoLoc Lifted)), ("t_8321_8322_8327", (13, TyVarRecord NoLoc (M.fromList [("mass", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8322_8326" 17160}) []))]))), ("t_8321_8322_8328", (14, TyVarFree NoLoc Lifted)), ("t_8321_8322_8329", (15, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8323_8328", (16, TyVarFree NoLoc Lifted)), ("kt_8321_8323_8329", (17, TyVarFree NoLoc Lifted)), ("t_8321_8324_8320", (17, TyVarRecord NoLoc (M.fromList [("position", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8323_8329" 17175}) []))]))), ("t_8321_8324_8321", (18, TyVarFree NoLoc Lifted)), ("t_8321_8324_8322", (19, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8321_8324_8323", (19, TyVarFree NoLoc Lifted)), ("t_8321_8324_8324", (19, TyVarRecord NoLoc (M.fromList [("x", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8324_8323" 17179}) []))]))), ("kt_8321_8324_8325", (19, TyVarFree NoLoc Lifted)), ("t_8321_8324_8326", (19, TyVarRecord NoLoc (M.fromList [("x", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8324_8325" 17181}) []))]))), ("t_8321_8325_8321", (20, TyVarFree NoLoc Lifted)), ("t_8321_8325_8322", (21, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8321_8325_8323", (21, TyVarFree NoLoc Lifted)), ("t_8321_8325_8324", (21, TyVarRecord NoLoc (M.fromList [("y", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8325_8323" 17190}) []))]))), ("kt_8321_8325_8325", (21, TyVarFree NoLoc Lifted)), ("t_8321_8325_8326", (21, TyVarRecord NoLoc (M.fromList [("y", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8325_8325" 17192}) []))]))), ("t_8321_8326_8321", (22, TyVarFree NoLoc Lifted)), ("t_8321_8326_8322", (23, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8321_8326_8323", (23, TyVarFree NoLoc Lifted)), ("t_8321_8326_8324", (23, TyVarRecord NoLoc (M.fromList [("z", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8326_8323" 17201}) []))]))), ("kt_8321_8326_8325", (23, TyVarFree NoLoc Lifted)), ("t_8321_8326_8326", (23, TyVarRecord NoLoc (M.fromList [("z", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8321_8326_8325" 17203}) []))]))), ("t_8321_8327_8321", (24, TyVarFree NoLoc Lifted)), ("t_8321_8327_8322", (25, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8323", (25, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8324", (25, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8327_8329", (25, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8328_8328", (25, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321_8329_8327", (20, TyVarFree NoLoc Lifted)), ("t_8321_8329_8328", (21, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8325", (22, TyVarFree NoLoc Lifted)), ("t_8322_8320_8326", (23, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8327", (23, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8320_8328", (23, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8320_8329", (23, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8322_8321_8324", (23, TyVarFree NoLoc Lifted)), ("t_8322_8321_8325", (23, TyVarRecord NoLoc (M.fromList [("tree_level", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8321_8324" 17264}) []))]))), ("num_8322_8322_8320", (23, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8322_8325", (24, TyVarFree NoLoc Lifted)), ("t_8322_8322_8326", (25, TyVarPrim NoLoc [Bool])), ("t_8322_8322_8327", (25, TyVarPrim NoLoc [Bool])), ("kt_8322_8322_8328", (25, TyVarFree NoLoc Lifted)), ("t_8322_8322_8329", (25, TyVarRecord NoLoc (M.fromList [("is_leaf", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8322_8328" 17280}) []))]))), ("t_8322_8323_8324", (25, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8324_8329", (26, TyVarFree NoLoc Lifted)), ("kt_8322_8325_8324", (27, TyVarFree NoLoc Lifted)), ("t_8322_8325_8325", (27, TyVarRecord NoLoc (M.fromList [("parent", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8325_8324" 17311}) []))]))), ("t_8322_8325_8326", (27, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8325_8327", (27, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("kt_8322_8326_8322", (25, TyVarFree NoLoc Lifted)), ("t_8322_8326_8323", (25, TyVarRecord NoLoc (M.fromList [("children", Scalar (TypeVar NoUniqueness (QualName {qualQuals = [], qualLeaf = VName "kt_8322_8326_8322" 17322}) []))]))), ("t_8322_8326_8324", (26, TyVarFree NoLoc Lifted)), ("a_8322_8326_8325", (27, TyVarFree NoLoc Lifted)), ("b_8322_8326_8326", (27, TyVarFree NoLoc Lifted)), ("a_8322_8326_8327", (27, TyVarFree NoLoc Unlifted)), ("x_8322_8326_8328", (27, TyVarFree NoLoc Unlifted)), ("a_8322_8326_8329", (27, TyVarFree NoLoc Lifted)), ("b_8322_8327_8320", (27, TyVarFree NoLoc Lifted)), ("c_8322_8327_8321", (27, TyVarFree NoLoc Lifted)), ("t_8322_8327_8322", (27, TyVarPrim NoLoc [Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8327_8323", (27, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8327_8324", (27, TyVarFree NoLoc Lifted)), ("a_8322_8328_8327", (27, TyVarFree NoLoc Unlifted)), ("t_8322_8328_8328", (27, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8322_8328_8329", (27, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8322_8329_8328", (28, TyVarFree NoLoc Lifted)), ("num_8322_8329_8329", (29, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8320_8320", (29, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), ("index_8323_8320_8321", (29, TyVarFree NoLoc Unlifted)), ("index_elem_8323_8320_8322", (29, TyVarFree NoLoc Unlifted)), ("t_8323_8320_8323", (29, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8323_8320_8324", (29, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("num_8323_8320_8329", (29, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), ("if_t_8323_8321_8324", (25, TyVarFree NoLoc SizeLifted)), ("if_t_8323_8321_8325", (9, TyVarFree NoLoc SizeLifted)), ("t_8323_8321_8326", (6, TyVarFree NoLoc Lifted)), ("t_8323_8321_8327", (6, TyVarFree NoLoc Lifted)), ("t_8323_8321_8328", (6, TyVarFree NoLoc Lifted)), ("t_8323_8321_8329", (6, TyVarFree NoLoc Lifted)), ("t_8322_8326_8324_8323_8322_8320", (26, TyVarFree NoLoc Lifted)), ("a_8322_8326_8325_8323_8322_8321", (27, TyVarFree NoLoc Lifted)), ("kt_8322_8326_8322_8323_8322_8322", (25, TyVarFree NoLoc Lifted)), ("kt_8329_8321_8323_8322_8323", (11, TyVarFree NoLoc Lifted)), ("kt_8326_8328_8323_8322_8324", (11, TyVarFree NoLoc Lifted)), ("kt_8324_8324_8323_8322_8325", (9, TyVarFree NoLoc Lifted))] + ), + ( [ "t_8326_8327" ~ "[]t_8326_8327_8327_8328", + "t_8326_8328" ~ "[]t_8326_8328_8327_8329", + "t_8326_8329" ~ "[]t_8326_8329_8328_8320", + "t_8326_8320" ~ "[]t_8326_8320_8328_8321", + "t_8327_8325" ~ "[]t_8327_8325_8328_8322", + "t_8327_8326" ~ "[]t_8327_8326_8328_8323", + "t_8327_8327" ~ "[]t_8327_8327_8328_8324", + "t_8326_8321" ~ "[]t_8326_8321_8328_8325", + "t_8325_8329" ~ "[]t_8325_8329_8328_8326", + "a_8324_8324" ~ "[]a_8324_8324_8328_8327", + "t_8324_8323" ~ "[]t_8324_8323_8328_8328", + "t_8323_8322" ~ "[]t_8323_8322_8328_8329", + "t_8321" ~ "float_8320", + "[]f32" ~ "[]a_8326", + "[]f32" ~ "[]b_8327", + "[]f32" ~ "[]c_8328", + "[]f32" ~ "[]a_8321_8325", + "[]f32" ~ "[]b_8321_8326", + "[]f32" ~ "[]c_8321_8327", + "(f32, f32, f32) -> {mass: f32} -> (f32, f32, f32) -> {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}" ~ "a_8322 -> b_8323 -> c_8324 -> x_8325", + "[](a_8326, b_8327, c_8328)" ~ "[]a_8322", + "[]f32" ~ "[]b_8323", + "[](a_8321_8325, b_8321_8326, c_8321_8327)" ~ "[]c_8324", + "t_8323_8322" ~ "[]x_8325", + "i32" ~ "i32", + "t_8321" ~ "f32", + "f32" ~ "f32", + "f32" ~ "f32", + "t_8323_8322" ~ "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}", + "t_8324_8323" ~ "[]{mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}", + "{b: {mass: f32, position: {x: f32, y: f32, z: f32}, velocity: {x: f32, y: f32, z: f32}}} -> ((f32, f32, f32), f32, (f32, f32, f32))" ~ "a_8324_8326 -> x_8324_8327", + "t_8324_8323" ~ "[]a_8324_8326", + "[]x_8324_8327" ~ "a_8324_8324", + "{xs: [](a_8325_8322, b_8325_8323, c_8325_8324)} -> ([]a_8325_8322, []b_8325_8323, []c_8325_8324)" ~ "a_8324_8324 -> b_8324_8325", + "(t_8325_8329, t_8326_8320, t_8326_8321)" ~ "b_8324_8325", + "t_8325_8329" ~ "[](a_8326_8322, b_8326_8323, c_8326_8324)", + "(t_8326_8327, t_8326_8328, t_8326_8329)" ~ "([]a_8326_8322, []b_8326_8323, []c_8326_8324)", + "t_8326_8321" ~ "[](a_8327_8320, b_8327_8321, c_8327_8322)", + "(t_8327_8325, t_8327_8326, t_8327_8327)" ~ "([]a_8327_8320, []b_8327_8321, []c_8327_8322)", + "([]f32, []f32, []f32, []f32, []f32, []f32, []f32)" ~ "(t_8326_8327, t_8326_8328, t_8326_8329, t_8326_8320, t_8327_8325, t_8327_8326, t_8327_8327)" + ], + M.empty, + M.fromList [("float_8320", (11, TyVarPrim NoLoc [FloatType Float16, FloatType Float32, FloatType Float64])), ("t_8321", (12, TyVarFree NoLoc Lifted)), ("a_8322", (13, TyVarFree NoLoc Unlifted)), ("b_8323", (13, TyVarFree NoLoc Unlifted)), ("c_8324", (13, TyVarFree NoLoc Unlifted)), ("x_8325", (13, TyVarFree NoLoc Unlifted)), ("a_8326", (13, TyVarFree NoLoc Unlifted)), ("b_8327", (13, TyVarFree NoLoc Unlifted)), ("c_8328", (13, TyVarFree NoLoc Unlifted)), ("a_8321_8325", (13, TyVarFree NoLoc Unlifted)), ("b_8321_8326", (13, TyVarFree NoLoc Unlifted)), ("c_8321_8327", (13, TyVarFree NoLoc Unlifted)), ("t_8323_8322", (14, TyVarFree NoLoc Lifted)), ("t_8324_8323", (16, TyVarFree NoLoc Lifted)), ("a_8324_8324", (17, TyVarFree NoLoc Lifted)), ("b_8324_8325", (17, TyVarFree NoLoc Lifted)), ("a_8324_8326", (17, TyVarFree NoLoc Unlifted)), ("x_8324_8327", (17, TyVarFree NoLoc Unlifted)), ("a_8325_8322", (17, TyVarFree NoLoc Unlifted)), ("b_8325_8323", (17, TyVarFree NoLoc Unlifted)), ("c_8325_8324", (17, TyVarFree NoLoc Unlifted)), ("t_8325_8329", (18, TyVarFree NoLoc Lifted)), ("t_8326_8320", (18, TyVarFree NoLoc Lifted)), ("t_8326_8321", (18, TyVarFree NoLoc Lifted)), ("a_8326_8322", (19, TyVarFree NoLoc Unlifted)), ("b_8326_8323", (19, TyVarFree NoLoc Unlifted)), ("c_8326_8324", (19, TyVarFree NoLoc Unlifted)), ("t_8326_8327", (20, TyVarFree NoLoc Lifted)), ("t_8326_8328", (20, TyVarFree NoLoc Lifted)), ("t_8326_8329", (20, TyVarFree NoLoc Lifted)), ("a_8327_8320", (21, TyVarFree NoLoc Unlifted)), ("b_8327_8321", (21, TyVarFree NoLoc Unlifted)), ("c_8327_8322", (21, TyVarFree NoLoc Unlifted)), ("t_8327_8325", (22, TyVarFree NoLoc Lifted)), ("t_8327_8326", (22, TyVarFree NoLoc Lifted)), ("t_8327_8327", (22, TyVarFree NoLoc Lifted)), ("t_8326_8327_8327_8328", (20, TyVarFree NoLoc Lifted)), ("t_8326_8328_8327_8329", (20, TyVarFree NoLoc Lifted)), ("t_8326_8329_8328_8320", (20, TyVarFree NoLoc Lifted)), ("t_8326_8320_8328_8321", (18, TyVarFree NoLoc Lifted)), ("t_8327_8325_8328_8322", (22, TyVarFree NoLoc Lifted)), ("t_8327_8326_8328_8323", (22, TyVarFree NoLoc Lifted)), ("t_8327_8327_8328_8324", (22, TyVarFree NoLoc Lifted)), ("t_8326_8321_8328_8325", (18, TyVarFree NoLoc Lifted)), ("t_8325_8329_8328_8326", (18, TyVarFree NoLoc Lifted)), ("a_8324_8324_8328_8327", (17, TyVarFree NoLoc Lifted)), ("t_8324_8323_8328_8328", (16, TyVarFree NoLoc Lifted)), ("t_8323_8322_8328_8329", (14, TyVarFree NoLoc Lifted))] + ) + ] diff --git a/src-testing/Language/Futhark/TypeChecker/TySolveBenchmarks.hs b/src-testing/Language/Futhark/TypeChecker/TySolveBenchmarks.hs index c8dbdcc4d6..6f477375d7 100644 --- a/src-testing/Language/Futhark/TypeChecker/TySolveBenchmarks.hs +++ b/src-testing/Language/Futhark/TypeChecker/TySolveBenchmarks.hs @@ -2,6 +2,8 @@ module Language.Futhark.TypeChecker.TySolveBenchmarks (benchmarks) where import Criterion (Benchmark, bench, bgroup, whnf) import Data.Map qualified as M +import Generated.AllFutBenchmarks +import Language.Futhark (qualName) import Language.Futhark.Syntax import Language.Futhark.SyntaxTests () import Language.Futhark.TypeChecker.Constraints @@ -13,7 +15,8 @@ import Language.Futhark.TypeChecker.Constraints TyVars, ) import Language.Futhark.TypeChecker.Monad (TypeError (..)) -import Language.Futhark.TypeChecker.TySolve (Solution, UnconTyVar, solve) +import Language.Futhark.TypeChecker.TySolve as N (Solution, UnconTyVar, solve) +import Language.Futhark.TypeChecker.TySolveOld as O (solve) (~) :: TypeBase () NoUniqueness -> TypeBase () NoUniqueness -> CtTy () t1 ~ t2 = CtEq (Reason mempty) t1 t2 @@ -21,23 +24,76 @@ t1 ~ t2 = CtEq (Reason mempty) t1 t2 tv :: VName -> Level -> (VName, (Level, TyVarInfo ())) tv v lvl = (v, (lvl, TyVarFree mempty Unlifted)) -solve' :: +solveNew :: ( [CtTy ()], TyParams, TyVars () ) -> Either TypeError ([UnconTyVar], Solution) -solve' (constraints, typarams, tyvars) = solve constraints typarams tyvars +solveNew (constraints, typarams, tyvars) = N.solve constraints typarams tyvars + +solveOld :: + ( [CtTy ()], + TyParams, + TyVars () + ) -> + Either TypeError ([UnconTyVar], Solution) +solveOld (constraints, typarams, tyvars) = O.solve constraints typarams tyvars + +generateContraints :: Int -> ([CtTy ()], TyParams, TyVars ()) +generateContraints num_vars + | num_vars <= 0 = + ([], mempty, mempty) + | num_vars == 1 = + let v0_name = VName (nameFromString "v_0") 0 + ty_vars = M.fromList [tv v0_name 0] + in ([], mempty, ty_vars) + | otherwise = + let var_names = + [ VName (nameFromString ("v_" ++ show i)) i + | i <- [0 .. num_vars - 1] + ] + + ty_vars = M.fromList $ map (`tv` 0) var_names + + mkTy :: VName -> TypeBase () NoUniqueness + mkTy v = Scalar (TypeVar NoUniqueness (qualName v) []) + + cts = + zipWith + (\v_i v_j -> mkTy v_i ~ mkTy v_j) + (init var_names) + (tail var_names) + ++ ["v_0" ~ "i32"] + + ty_params = mempty + in (cts, ty_params, ty_vars) benchmarks :: Benchmark benchmarks = bgroup "TySolve" - [ bench "trivial" $ - whnf - solve' - ( ["a_0" ~ "b_1"], - mempty, - M.fromList [tv "a_0" 0] + [ bgroup "Synthetic" $ + concatMap + ( \n -> + [ bench (show n ++ " variables (new)") $ + whnf solveNew (generateContraints n), + bench (show n ++ " variables (old)") $ + whnf solveOld (generateContraints n) + ] + ) + sizes, + bgroup "Converted" $ + concatMap + ( \(name, dataCase) -> + [ bench (name <> " (new)") $ whnf solveNew dataCase, + bench (name <> " (old)") $ whnf solveOld dataCase + ] ) + allFutBenchmarkCases ] + where + start = 100 + end = 1000 + i = 100 + sizes = [start, start + i .. end] diff --git a/src-testing/Language/Futhark/TypeChecker/TySolveTests.hs b/src-testing/Language/Futhark/TypeChecker/TySolveTests.hs index 944a2dd233..e154e32bad 100644 --- a/src-testing/Language/Futhark/TypeChecker/TySolveTests.hs +++ b/src-testing/Language/Futhark/TypeChecker/TySolveTests.hs @@ -13,10 +13,11 @@ import Language.Futhark.TypeChecker.Constraints TyVarInfo (..), TyVars, ) -import Language.Futhark.TypeChecker.Monad (prettyTypeError) +import Language.Futhark.TypeChecker.Monad (TypeError (TypeError), prettyTypeError) import Language.Futhark.TypeChecker.TySolve import Test.Tasty (TestTree, testGroup) -import Test.Tasty.HUnit (Assertion, assertFailure, testCase, (@?=)) +import Test.Tasty.HUnit (Assertion, assertBool, assertFailure, testCase, (@?=)) +import Text.Regex.TDFA ((=~)) testSolve :: [CtTy ()] -> @@ -29,6 +30,19 @@ testSolve constraints typarams tyvars expected = Right s -> s @?= expected Left e -> assertFailure $ docString $ prettyTypeError e +testSolveFail :: + [CtTy ()] -> + TyParams -> + TyVars () -> + String -> + Assertion +testSolveFail constraints typarams tyvars expected = + case solve constraints typarams tyvars of + Left (TypeError _ _ actualMsg) -> + let regexMatch :: Bool = docString actualMsg =~ expected + in assertBool "Regex doesn't match" regexMatch + Right _ -> assertFailure "Expected type error, but got a solution" + -- When writing type variables/names here (a_0, b_1), make *sure* that -- the numbers are distinct. These are all that actually matter for -- determining identity. @@ -36,24 +50,26 @@ testSolve constraints typarams tyvars expected = (~) :: TypeBase () NoUniqueness -> TypeBase () NoUniqueness -> CtTy () t1 ~ t2 = CtEq (Reason mempty) t1 t2 -tv :: VName -> Level -> (VName, (Level, TyVarInfo ())) -tv v lvl = (v, (lvl, TyVarFree mempty Unlifted)) +tvFree :: VName -> Level -> (VName, (Level, TyVarInfo ())) +tvFree v lvl = (v, (lvl, TyVarFree mempty Unlifted)) + +-- tvPrim :: VName -> Level -> [PrimType] -> (VName, (Level, TyVarInfo ())) +-- tvPrim v lvl types = (v, (lvl, TyVarPrim mempty types)) + +tvRecord :: VName -> Level -> M.Map Name (TypeBase () NoUniqueness) -> (VName, (Level, TyVarInfo ())) +tvRecord v lvl fields = (v, (lvl, TyVarRecord mempty fields)) + +-- tvSum :: VName -> Level -> M.Map Name [TypeBase () NoUniqueness] -> (VName, (Level, TyVarInfo ())) +-- tvSum v lvl fields = (v, (lvl, TyVarSum mempty fields)) + +typaram :: VName -> Level -> Liftedness -> (VName, (Level, Liftedness, Loc)) +typaram v lvl liftedness = (v, (lvl, liftedness, noLoc)) tests :: TestTree tests = testGroup "Unsized constraint solver" - [ testCase "empty" $ - testSolve [] mempty mempty ([], mempty), - -- - testCase "a_0 ~ b_1" $ - testSolve - ["a_0" ~ "b_1"] - mempty - (M.fromList [tv "a_0" 0]) - ([], M.fromList [("a_0", Right "b_1")]), - -- - testCase "infer unlifted" $ + [ testCase "infer unlifted" $ testSolve [ "t\8320_9896" ~ "if_t\8322_9898", "t\8321_9897" ~ "if_t\8322_9898", @@ -73,5 +89,327 @@ tests = ("t\8321_9897", Right "if_t\8322_9898"), ("t\8323_9899", Right "if_t\8322_9898") ] + ), + testCase "empty" $ + testSolve [] mempty mempty ([], mempty), + testCase "b_1 ~ a_0" $ + testSolve + ["b_1" ~ "a_0"] + mempty + (M.fromList [tvFree "b_1" 0]) + ([], M.fromList [("b_1", Right "a_0")]), + testCase "a_0 ~ b_1" $ + testSolve + ["a_0" ~ "b_1"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0]) + ([("b_1", Unlifted)], M.fromList [("a_0", Right "b_1")]), + testCase "multiple" $ + testSolve + ["b_1" ~ "a_0", "d_3" ~ "c_2", "e_4" ~ "c_2", "c_2" ~ "a_0"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0, tvFree "d_3" 0, tvFree "e_4" 0]) + ([("a_0", Unlifted)], M.fromList [("b_1", Right "a_0"), ("c_2", Right "a_0"), ("d_3", Right "a_0"), ("e_4", Right "a_0")]), + testCase "Two variables" $ + testSolve + ["a_0" ~ "b_1", "c_2" ~ "d_3"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "c_2" 0]) + ([], M.fromList [("a_0", Right "b_1"), ("c_2", Right "d_3")]), + testCase "i32 + (i32 + i32)" $ + testSolve + [ "i32 -> i32 -> a_0" ~ "i32 -> i32 -> i32", + "i32 -> a_0 -> b_1" ~ "i32 -> i32 -> i32" + ] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0]) + ([], M.fromList [("a_0", Right "i32"), ("b_1", Right "i32")]), + testCase "((λx -> λy -> x * y) i32) i32" $ + testSolve + [ "a_0 -> b_1 -> c_2" ~ "i32 -> i32 -> i32", + "a_0 -> b_1 -> c_2" ~ "i32 -> d_3", + "d_3" ~ "i32 -> e_4" + ] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0, tvFree "d_3" 0, tvFree "e_4" 0]) + ( [], + M.fromList + [ ("a_0", Right "i32"), + ("b_1", Right "i32"), + ("c_2", Right "i32"), + ("d_3", Right "i32 -> i32"), + ("e_4", Right "i32") + ] + ), + testCase "rec λf -> λn -> if n == 0 then 1 else n * (f (n - 1))" $ + testSolve + [ "b_1 -> i32 -> c_2" ~ "i32 -> i32 -> bool", + "b_1 -> i32 -> d_3" ~ "i32 -> i32 -> i32", + "a_0" ~ "d_3 -> e_4", + "b_1 -> e_4 -> f_5" ~ "i32 -> i32 -> i32", + "c_2" ~ "bool", + "i32" ~ "f_5", + "g_6 -> g_6" ~ "a_0 -> b_1 -> i32" + ] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0, tvFree "d_3" 0, tvFree "e_4" 0, tvFree "f_5" 0, tvFree "g_6" 0]) + ( [], + M.fromList + [ ("a_0", Right "i32 -> i32"), + ("b_1", Right "i32"), + ("c_2", Right "bool"), + ("d_3", Right "i32"), + ("e_4", Right "i32"), + ("f_5", Right "i32"), + ("g_6", Right "i32 -> i32") + ] + ), + testCase "let id = λx -> x in id id" $ + testSolve + ["b_1 -> b_1" ~ "(c_2 -> c_2) -> d_3"] + mempty + (M.fromList [tvFree "b_1" 0, tvFree "c_2" 0, tvFree "d_3" 0]) + ( [("c_2", Unlifted)], + M.fromList + [ ("b_1", Right "c_2 -> c_2"), + ("d_3", Right "c_2 -> c_2") + ] + ), + testCase "a_0 ~ i32" $ + testSolve + ["a_0" ~ "i32"] + mempty + (M.fromList [tvFree "a_0" 0]) + ([], M.fromList [("a_0", Right "i32")]), + testCase "a_0 ~ a_0" $ + testSolve + ["a_0" ~ "a_0"] + mempty + (M.fromList [tvFree "a_0" 0]) + ([("a_0", Unlifted)], mempty), + testCase "non-unifiable types" $ + testSolveFail + ["a_0" ~ "i32", "a_0" ~ "bool"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Cc]annot unify).?", + testCase "infinite type (function) 1" $ + testSolveFail + ["a_0" ~ "a_0 -> b_1"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Oo]ccurs check).?", + -- ! This case acts weird for the original implementation. + testCase "infinite type (function) 2" $ + testSolveFail + ["a_0" ~ "b_1 -> i32", "b_1" ~ "c_2", "b_1" ~ "d_3", "a_0" ~ "d_3"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0, tvFree "d_3" 0]) + ".?([Oo]ccurs check).?", + testCase "infinite type (list)" $ + testSolveFail + ["a_0" ~ "[]a_0"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Oo]ccurs check).?", + testCase "infinite type (tuple)" $ + testSolveFail + ["a_0" ~ "(a_0, bool)"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Oo]ccurs check).?", + testCase "infinite type (record) 1" $ + testSolveFail + ["a_0" ~ "{foo: a_0, bar: f32}"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Oo]ccurs check).?", + -- ! This case never finishes for the original implementation. + testCase "infinite type (record) 2" $ + testSolveFail + ["a_0" ~ "{foo: b_1}", "b_1" ~ "c_2", "a_0" ~ "c_2"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0]) + ".?([Oo]ccurs check).?", + testCase "infinite type (record) 3" $ + testSolveFail + ["a_0" ~ "{foo: b_1}", "c_2" ~ "b_1", "a_0" ~ "c_2"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0]) + ".?([Oo]ccurs check).?", + testCase "infinite type (record) 4" $ + testSolveFail + ["a_0" ~ "{foo: b_1}", "c_2" ~ "b_1", "d_3" ~ "c_2", "a_0" ~ "c_2"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0, tvFree "d_3" 0]) + ".?([Oo]ccurs check).?", + -- testCase "infinite type (sum type)" $ + -- testSolveFail + -- ["a_0" ~ "#foo a_0"] + -- mempty + -- (M.fromList [tvFree "a_0" 0]) + -- ".?([Oo]ccurs check).?", + + testCase "infinite type (consuming array param)" $ + testSolveFail + ["a_0" ~ "*[]a_0"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Oo]ccurs check).?", + -- ! This case acts weird for the original implementation. + testCase "infinite type (nested)" $ + testSolveFail + ["a_0" ~ "{foo: i32, bar: b_1}", "b_1" ~ "c_2", "c_2" ~ "i32 -> []a_0"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0]) + ".?([Oo]ccurs check).?", + testCase "vector and 2D matrix" $ + testSolveFail + ["a_0" ~ "[]i32", "a_0" ~ "[][]i32"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Cc]annot unify).?", + testCase "different array types" $ + testSolveFail + ["a_0" ~ "[]f64", "a_0" ~ "[]i64"] + mempty + (M.fromList [tvFree "a_0" 0]) + ".?([Cc]annot unify).?", + testCase "simple record" $ + testSolve + ["a_0" ~ "{foo: i32, bar: bool}"] + mempty + (M.fromList [tvFree "a_0" 0]) + ([], M.fromList [("a_0", Right "{foo: i32, bar: bool}")]), + testCase "record 2" $ + testSolve + ["a_0" ~ "{foo: b_1, bar: c_2}", "b_1" ~ "c_2", "c_2" ~ "i64"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0]) + ( [], + M.fromList + [ ("a_0", Right "{foo: i64, bar: i64}"), + ("b_1", Right "i64"), + ("c_2", Right "i64") + ] + ), + testCase "record 3" $ + testSolve + ["a_0" ~ "{foo: b_1, bar: c_2}", "b_1" ~ "c_2"] + (M.fromList [typaram "c_2" 0 Lifted]) + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0]) + ( [], + M.fromList + [ ("a_0", Right "{foo: c_2, bar: c_2}"), + ("b_1", Right "c_2") + ] + ), + testCase "tuple" $ + testSolve + ["a_0" ~ "(b_1, c_2, d_3)", "c_2" ~ "d_3"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0, tvFree "c_2" 0, tvFree "d_3" 0]) + ( [("b_1", Unlifted), ("d_3", Unlifted)], + M.fromList + [ ("a_0", Right "(b_1, d_3, d_3)"), + ("c_2", Right "d_3") + ] + ), + testCase "compatible levels" $ + testSolve + ["a_0" ~ "b_1"] + (M.fromList [typaram "a_0" 0 Unlifted]) + (M.fromList [tvFree "b_1" 1]) + ([], M.fromList [("b_1", Right "a_0")]), + testCase "scope violation 1" $ + testSolveFail + ["a_0" ~ "b_1"] + (M.fromList [typaram "b_1" 1 Unlifted]) + (M.fromList [tvFree "a_0" 0]) + ".?(scope violation).?", + testCase "scope violation 2" $ + testSolveFail + ["a_0" ~ "b_1", "b_1" ~ "c_2"] + (M.fromList [typaram "c_2" 1 Unlifted]) + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 1]) + ".?(scope violation).?", + testCase "differently sized tuples" $ + testSolveFail + ["a_0" ~ "(i32, c_2)", "b_1" ~ "(i32, c_2, bool)", "a_0" ~ "b_1"] + mempty + (M.fromList [tvFree "a_0" 0, tvFree "b_1" 0]) + ".?([Cc]annot unify).?", + testCase "Prim type last substitution" $ + testSolve + [ "t\8321_8321" ~ "num\8320_8320", + "index\8322_8322" ~ "index_elem\8323_8323", + "[]t_0" ~ "[]index_elem\8323_8323" + ] + (M.fromList [typaram "t_0" 0 Unlifted]) + ( M.fromList + [ ("num\8320_8320", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float16, FloatType Float32, FloatType Float64])), + ("t\8321_8321", (2, TyVarPrim NoLoc [Signed Int8, Signed Int16, Signed Int32, Signed Int64])), + ("index\8322_8322", (2, TyVarFree NoLoc Unlifted)), + ("index_elem\8323_8323", (2, TyVarFree NoLoc Unlifted)) + ] ) + ( [], + M.fromList + [ ("num\8320_8320", Left [Signed Int8, Signed Int16, Signed Int32, Signed Int64]), + ("t\8321_8321", Left [Signed Int8, Signed Int16, Signed Int32, Signed Int64]), + ("index\8322_8322", Right "t_0"), + ("index_elem\8323_8323", Right "t_0") + ] + ), + testCase "record with polymorphic fields" $ + testSolve + [ "d_3" ~ "{foo: e_4, bar: f_5}", + "e_4" ~ "i32", + "f64" ~ "f_5", + "a_0" ~ "d_3" + ] + mempty + ( M.fromList + [ tvRecord "a_0" 0 $ + M.fromList + [ ("foo", Scalar (Prim (Signed Int32))), + ("bar", Scalar (Prim (FloatType Float64))) + ], + tvFree "d_3" 0, + tvFree "e_4" 0, + tvFree "f_5" 0 + ] + ) + ( [], + M.fromList + [ ("a_0", Right "{foo: i32, bar: f64}"), + ("d_3", Right "{foo: i32, bar: f64}"), + ("e_4", Right "i32"), + ("f_5", Right "f64") + ] + ), + testCase "opaque type" $ + testSolveFail + ["a_0" ~ "i32"] + mempty + mempty + ".?([Cc]annot unify).?", + testCase "liftedness propagation (Lifted -> SizeLifted)" $ + testSolve + ["a_0" ~ "b_1"] + mempty + (M.fromList [("a_0", (0, TyVarFree mempty SizeLifted)), ("b_1", (0, TyVarFree mempty Lifted))]) + ([("b_1", SizeLifted)], M.fromList [("a_0", Right "b_1")]), + testCase "liftedness propagation (Lifted -> Unlifted)" $ + testSolve + ["a_0" ~ "b_1"] + mempty + (M.fromList [("a_0", (0, TyVarFree mempty Unlifted)), ("b_1", (0, TyVarFree mempty Lifted))]) + ([("b_1", Unlifted)], M.fromList [("a_0", Right "b_1")]), + testCase "liftedness propagation (SizeLifted -> Unlifted)" $ + testSolve + ["a_0" ~ "b_1"] + mempty + (M.fromList [("a_0", (0, TyVarFree mempty Unlifted)), ("b_1", (0, TyVarFree mempty SizeLifted))]) + ([("b_1", Unlifted)], M.fromList [("a_0", Right "b_1")]) ] diff --git a/src-testing/futhark_benchmarks.hs b/src-testing/futhark_benchmarks.hs index c2ad20b150..c4836ac3fe 100644 --- a/src-testing/futhark_benchmarks.hs +++ b/src-testing/futhark_benchmarks.hs @@ -2,9 +2,11 @@ module Main (main) where import Criterion.Main import Language.Futhark.ParserBenchmarks qualified +import Language.Futhark.TypeChecker.TySolveBenchmarks qualified main :: IO () main = defaultMain - [ Language.Futhark.ParserBenchmarks.benchmarks + [ Language.Futhark.ParserBenchmarks.benchmarks, + Language.Futhark.TypeChecker.TySolveBenchmarks.benchmarks ] diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs index 7664ce3aed..75ebead21d 100644 --- a/src/Language/Futhark/TypeChecker/TySolve.hs +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -1,4 +1,3 @@ --- | The constraint solver for unsized type equality constraints. module Language.Futhark.TypeChecker.TySolve ( Type, Solution, @@ -9,87 +8,35 @@ where import Control.Monad import Control.Monad.Except -import Control.Monad.State +import Control.Monad.Reader +import Control.Monad.ST import Data.Bifunctor import Data.List qualified as L import Data.Loc import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S -import Debug.Trace +import Debug.Trace (trace) import Futhark.Util (isEnvVarAtLeast) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.TypeChecker.Constraints import Language.Futhark.TypeChecker.Error import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote, prettyTypeError) -import Language.Futhark.TypeChecker.Types (substTyVars) +import Language.Futhark.TypeChecker.UnionFind -- | The type representation used by the constraint solver. Agnostic -- to sizes and uniqueness. type Type = CtType () --- | A (partial) solution for a type variable. -data TyVarSol - = -- | Has been substituted with this. - TyVarSol Type - | -- | Is an explicit (rigid) type parameter in the source program. - TyVarParam Level Liftedness Loc - | -- | Not substituted yet; has this constraint. - TyVarUnsol (TyVarInfo ()) - deriving (Show) - -newtype SolverState = SolverState - { -- | Left means linked to this other type variable. - solverTyVars :: M.Map TyVar (Either VName TyVarSol) - } - -initialState :: TyParams -> TyVars () -> SolverState -initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars - where - f (_lvl, info) = Right $ TyVarUnsol info - g (lvl, l, loc) = Right $ TyVarParam lvl l loc - -substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase () u) -substTyVar m v = - case M.lookup v m of - Just (Left v') -> substTyVar m v' - Just (Right (TyVarSol t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' - Just (Right TyVarParam {}) -> Nothing - Just (Right (TyVarUnsol {})) -> Nothing - Nothing -> Nothing - -maybeLookupTyVar :: TyVar -> SolveM (Maybe TyVarSol) -maybeLookupTyVar orig = do - tyvars <- gets solverTyVars - let f v = case M.lookup v tyvars of - Nothing -> pure Nothing - Just (Left v') -> f v' - Just (Right info) -> pure $ Just info - f orig - -lookupTyVar :: TyVar -> SolveM (Either (TyVarInfo ()) Type) -lookupTyVar orig = - maybe bad unpack <$> maybeLookupTyVar orig - where - bad = error $ "Unknown tyvar: " <> prettyNameString orig - unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig - unpack (TyVarSol t) = Right t - unpack (TyVarUnsol info) = Left info - --- | Variable must be flexible. -lookupTyVarInfo :: TyVar -> SolveM (TyVarInfo ()) -lookupTyVarInfo v = do - r <- lookupTyVar v - case r of - Left info -> pure info - Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v +type UF s = M.Map TyVar (TyVarNode s) -setLink :: TyVar -> VName -> SolveM () -setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} +newtype SolverState s = SolverState {solverTyVars :: UF s} -setInfo :: TyVar -> TyVarSol -> SolveM () -setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} +newtype SolveM s a = SolveM + { runSolveM :: ExceptT TypeError (ReaderT (SolverState s) (ST s)) a + } + deriving (Functor, Applicative, Monad, MonadError TypeError, MonadReader (SolverState s)) -- | A solution maps a type variable to its substitution. This -- substitution is complete, in the sense there are no right-hand @@ -100,89 +47,36 @@ type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) -- a constraint on how it can be instantiated. type UnconTyVar = (VName, Liftedness) -typeVar :: (Monoid u) => VName -> TypeBase dim u -typeVar v = Scalar $ TypeVar mempty (qualName v) [] +liftST :: ST s a -> SolveM s a +liftST = SolveM . lift . lift -solution :: SolverState -> ([UnconTyVar], Solution) -solution s = - ( mapMaybe unconstrained $ M.toList $ solverTyVars s, - M.mapMaybe mkSubst $ solverTyVars s - ) - where - mkSubst (Right (TyVarSol t)) = - Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t - mkSubst (Left v') = - Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ - mkSubst =<< M.lookup v' (solverTyVars s) - mkSubst (Right (TyVarUnsol (TyVarPrim _ pts))) = Just $ Left pts - mkSubst _ = Nothing - - unconstrained (v, Right (TyVarUnsol (TyVarFree _ l))) = Just (v, l) - unconstrained _ = Nothing - -newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} - deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) - --- Try to substitute as much information as we have. -enrichType :: Type -> SolveM Type -enrichType t = do - s <- get - pure $ substTyVars (substTyVar (solverTyVars s)) t - -typeError :: Loc -> Notes -> Doc () -> SolveM () -typeError loc notes msg = - throwError $ TypeError loc notes msg +getSol' :: TyVarNode s -> SolveM s TyVarSol +getSol' = liftST . getSol -occursCheck :: Reason Type -> VName -> Type -> SolveM () -occursCheck reason v tp = do - vars <- gets solverTyVars - let tp' = substTyVars (substTyVar vars) tp - when (v `S.member` typeVars tp') . typeError (locOf reason) mempty $ - "Occurs check: cannot instantiate" - <+> prettyName v - <+> "with" - <+> pretty tp - <> "." +union' :: TyVarNode s -> TyVarNode s -> SolveM s () +union' tv1 tv2 = liftST $ union tv1 tv2 -unifySharedConstructors :: - Reason Type -> - BreadCrumbs -> - M.Map Name [Type] -> - M.Map Name [Type] -> - SolveM () -unifySharedConstructors reason bcs cs1 cs2 = - forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> - if length ts1 == length ts2 - then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 - else - typeError (locOf reason) mempty $ - "Cannot unify type with constructor" - indent 2 (pretty (Sum (M.singleton c ts1))) - "with type of constructor" - indent 2 (pretty (Sum (M.singleton c ts2))) - "because they differ in arity." +unionNewSol' tv1 tv2 new_sol = liftST $ unionNewSol tv1 tv2 new_sol -unifySharedFields :: - Reason Type -> - BreadCrumbs -> - M.Map Name Type -> - M.Map Name Type -> - SolveM () -unifySharedFields reason bcs fs1 fs2 = - forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> - solveEq reason (matchingField f <> bcs) ts1 ts2 +unionNewSol' :: TyVarNode s -> TyVarNode s -> TyVarSol -> SolveM s () +getKey' :: TyVarNode s -> SolveM s TyVar +getKey' = liftST . getKey -scopeViolation :: Reason Type -> VName -> Type -> VName -> SolveM () -scopeViolation reason v1 ty v2 = - typeError (locOf reason) mempty $ - "Cannot unify type" - indent 2 (pretty ty) - "with" - <+> dquotes (prettyName v1) - <+> "(scope violation)." - "This is because" - <+> dquotes (prettyName v2) - <+> "is rigidly bound in a deeper scope." +initializeState :: TyParams -> TyVars () -> ST s (SolverState s) +initializeState typarams tyvars = do + tyvars' <- M.traverseWithKey f tyvars + typarams' <- M.traverseWithKey g typarams + pure $ SolverState $ typarams' <> tyvars' + where + f tv (_lvl, info) = makeTyVarNode tv info + g tv (lvl, lft, loc) = makeTyParamNode tv lvl lft loc + +typeError :: Loc -> Notes -> Doc () -> SolveM s () +typeError loc notes msg = + throwError $ TypeError loc notes msg + +typeVar :: (Monoid u) => VName -> TypeBase dim u +typeVar v = Scalar $ TypeVar mempty (qualName v) [] cannotUnify :: Reason Type -> @@ -190,10 +84,10 @@ cannotUnify :: BreadCrumbs -> Type -> Type -> - SolveM () + SolveM s () cannotUnify reason notes bcs t1 t2 = do - t1' <- enrichType t1 - t2' <- enrichType t2 + t1' <- substTyVars t1 + t2' <- substTyVars t2 case reason of Reason loc -> typeError loc notes . stack $ @@ -219,8 +113,8 @@ cannotUnify reason notes bcs t1 t2 = do ] <> [pretty bcs | not $ hasNoBreadCrumbs bcs] ReasonRetType loc expected actual -> do - expected' <- enrichType expected - actual' <- enrichType actual + expected' <- substTyVars expected + actual' <- substTyVars actual typeError loc notes . stack $ [ "Function body does not have expected type.", "Expected:" <+> align (pretty expected'), @@ -228,8 +122,8 @@ cannotUnify reason notes bcs t1 t2 = do ] <> [pretty bcs | not $ hasNoBreadCrumbs bcs] ReasonApply loc f e expected actual -> do - expected' <- enrichType expected - actual' <- enrichType actual + expected' <- substTyVars expected + actual' <- substTyVars actual typeError loc notes . stack $ [ header, "Expected:" <+> align (pretty expected'), @@ -271,211 +165,207 @@ cannotUnify reason notes bcs t1 t2 = do where fname' = maybe "expression" (dquotes . pretty) fname ReasonBranches loc former latter -> do - former' <- enrichType former - latter' <- enrichType latter + former' <- substTyVars former + latter' <- substTyVars latter typeError loc notes . stack $ [ "Branches differ in type.", "Former:" <+> pretty former', "Latter:" <+> pretty latter' ] --- Precondition: 'v' is currently flexible. -subTyVar :: Reason Type -> BreadCrumbs -> VName -> Type -> SolveM () -subTyVar reason bcs v t = do - occursCheck reason v t - v_info <- gets $ M.lookup v . solverTyVars +unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a +unsharedConstructorsMsg cs1 cs2 = + "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." + where + missing = + filter (`notElem` M.keys cs1) (M.keys cs2) + ++ filter (`notElem` M.keys cs2) (M.keys cs1) + +substTyVars :: (Monoid u) => TypeBase () u -> SolveM s (TypeBase () u) +substTyVars (Scalar (TypeVar u qn args)) = do + mb_node <- maybeLookupUF $ qualLeaf qn + case mb_node of + Just node -> do + sol <- getSol' node + qn_k <- qualName <$> getKey' node + case sol of + Solved t -> do + t' <- substTyVars t + pure $ second (const mempty) t' + _ -> makeTyVar qn_k + _ -> makeTyVar qn + where + makeTyVar qn' = do + args' <- mapM onArg args + pure $ Scalar $ TypeVar u qn' args' + onArg (TypeArgType t) = TypeArgType <$> substTyVars t + onArg d@(TypeArgDim _) = pure d +substTyVars p@(Scalar (Prim _)) = pure p +substTyVars (Scalar (Record fs)) = + Scalar . Record <$> traverse substTyVars fs +substTyVars (Scalar (Sum cs)) = + Scalar . Sum <$> traverse (mapM substTyVars) cs +substTyVars (Scalar (Arrow u pname d t1 (RetType ext t2))) = do + t1' <- substTyVars t1 + t2' <- substTyVars t2 + pure $ + Scalar $ + Arrow u pname d t1' $ + RetType ext $ + t2' `setUniqueness` uniqueness t2 +substTyVars (Array u shape elemt) = do + elemt' <- substTyVars $ Scalar elemt + pure $ arrayOfWithAliases u shape elemt' + +occursCheck :: Reason Type -> VName -> VName -> Type -> SolveM s () +occursCheck reason v k tp = do + let vars = typeVars tp + when (k `S.member` vars) . typeError (locOf reason) mempty $ + "Occurs check: cannot instantiate" + <+> prettyName v + <+> "with" + <+> pretty tp + <> "." + +bindTyVar :: + Reason Type -> + BreadCrumbs -> + VName -> + TyVarNode s -> + Type -> + SolveM s () +bindTyVar reason bcs v v_node t' = do + t <- substTyVars t' + k <- getKey' v_node + occursCheck reason v k t - -- Set a solution for v, then update info for t in case v has any - -- odd constraints. - setInfo v (TyVarSol t) + v_info <- getSol' v_node + + setInfo v_node $ Solved t case (v_info, t) of - (Just (Right (TyVarUnsol TyVarFree {})), _) -> - pure () - ( Just (Right (TyVarUnsol (TyVarPrim _ v_pts))), - _ - ) -> - if t `elem` map (Scalar . Prim) v_pts - then pure () - else cannotUnify reason notes bcs (typeVar v) t - where - notes = - aNote $ - "Cannot instance type that must be one of" - indent 2 (pretty v_pts) - "with" - indent 2 (pretty t) - ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), - Scalar (Sum cs2) - ) -> - if all (`elem` M.keys cs2) (M.keys cs1) - then unifySharedConstructors reason bcs cs1 cs2 - else cannotUnify reason notes bcs (typeVar v) t - where - notes = - aNote $ - "Cannot match type with constructors" - indent 2 (stack (map (("#" <>) . pretty) (M.keys cs1))) - "with type with constructors" - indent 2 (stack (map (("#" <>) . pretty) (M.keys cs2))) - unsharedConstructorsMsg cs1 cs2 - ( Just (Right (TyVarUnsol (TyVarSum _ cs1))), - _ - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type with constructors" - indent 2 (pretty (Sum cs1)) - "with type" - indent 2 (pretty t) - ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), - Scalar (Record fs2) - ) -> - if all (`elem` M.keys fs2) (M.keys fs1) - then unifySharedFields reason bcs fs1 fs2 - else - typeError (locOf reason) mempty $ - "Cannot unify record type with fields" - indent 2 (pretty (Record fs1)) - "with record type" - indent 2 (pretty (Record fs2)) - ( Just (Right (TyVarUnsol (TyVarRecord _ fs1))), - _ - ) -> - typeError (locOf reason) mempty $ - "Cannot unify record type with fields" - indent 2 (pretty (Record fs1)) - "with type" - indent 2 (pretty t) + (Unsolved TyVarFree {}, _) -> pure () + (Unsolved (TyVarPrim _ v_pts), _) -> + if t `elem` map (Scalar . Prim) v_pts + then pure () + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot instantiate type that must be one of" + indent 2 (pretty v_pts) + "with" + indent 2 (pretty t) + (Unsolved (TyVarSum _ cs1), Scalar (Sum cs2)) -> + if all (`elem` M.keys cs2) (M.keys cs1) + then unifySharedConstructors reason bcs cs1 cs2 + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot match type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs1))) + "with type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs2))) + unsharedConstructorsMsg cs1 cs2 + (Unsolved (TyVarSum _ cs1), _) -> + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty t) + (Unsolved (TyVarRecord _ fs1), Scalar (Record fs2)) -> + if all (`elem` M.keys fs2) (M.keys fs1) + then unifySharedFields reason bcs fs1 fs2 + else + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with record type" + indent 2 (pretty (Record fs2)) + (Unsolved (TyVarRecord _ fs1), _) -> + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty t) -- -- Internal error cases - (Just (Right TyVarSol {}), _) -> + (Solved {}, _) -> error $ "Type variable already solved: " <> prettyNameString v - (Just (Right TyVarParam {}), _) -> + (Param {}, _) -> error $ "Cannot substitute type parameter: " <> prettyNameString v - (Just Left {}, _) -> - error $ "Type variable already linked: " <> prettyNameString v - (Nothing, _) -> - error $ "subTyVar: Nothing v: " <> prettyNameString v - --- Precondition: 'v' and 't' are both currently flexible. --- --- The purpose of this function is to combine the partial knowledge we --- may have about these two type variables. -unionTyVars :: Reason Type -> BreadCrumbs -> VName -> VName -> SolveM () -unionTyVars reason bcs v t = do - v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars - t_info <- lookupTyVarInfo t - - -- Insert the link from v to t, and then update the info of t based - -- on the existing info of v and t. - setLink v t - - case (v_info, t_info) of - ( TyVarUnsol (TyVarFree _ v_l), - TyVarFree t_loc t_l - ) - | v_l /= t_l -> - setInfo t $ TyVarUnsol $ TyVarFree t_loc (min v_l t_l) - -- When either is completely unconstrained. - (TyVarUnsol TyVarFree {}, _) -> - pure () - ( TyVarUnsol info, - TyVarFree {} - ) -> - setInfo t (TyVarUnsol info) - -- - -- TyVarPrim cases - ( TyVarUnsol (TyVarPrim _ v_pts), - TyVarPrim t_loc t_pts - ) -> - let pts = L.intersect v_pts t_pts - in if null pts - then - typeError (locOf reason) mempty $ - "Cannot unify type that must be one of" - indent 2 (pretty v_pts) - "with type that must be one of" - indent 2 (pretty t_pts) - else setInfo t (TyVarUnsol (TyVarPrim t_loc pts)) - ( TyVarUnsol (TyVarPrim _ v_pts), - TyVarRecord {} - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type that must be one of" - indent 2 (pretty v_pts) - "with type that must be a record." - ( TyVarUnsol (TyVarPrim _ v_pts), - TyVarSum {} - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type that must be one of" - indent 2 (pretty v_pts) - "with type that must be sum." - -- - -- TyVarSum cases - ( TyVarUnsol (TyVarSum _ cs1), - TyVarSum loc cs2 - ) -> do - unifySharedConstructors reason bcs cs1 cs2 - let cs3 = cs1 <> cs2 - setInfo t (TyVarUnsol (TyVarSum loc cs3)) - ( TyVarUnsol TyVarSum {}, - TyVarPrim _ pts - ) -> - typeError (locOf reason) mempty $ - "A sum type cannot be one of" - indent 2 (pretty pts) - ( TyVarUnsol (TyVarSum _ cs1), - TyVarRecord _ fs - ) -> - typeError (locOf reason) mempty $ - "Cannot unify type with constructors" - indent 2 (pretty (Sum cs1)) - "with type" - indent 2 (pretty (Scalar (Record fs))) - -- - -- TyVarRecord cases - ( TyVarUnsol (TyVarRecord _ fs1), - TyVarRecord loc fs2 - ) -> do - unifySharedFields reason bcs fs1 fs2 - let fs3 = fs1 <> fs2 - setInfo t (TyVarUnsol (TyVarRecord loc fs3)) - ( TyVarUnsol TyVarRecord {}, - TyVarPrim _ pts - ) -> - typeError (locOf reason) mempty $ - "A record type cannot be one of" - indent 2 (pretty pts) - ( TyVarUnsol (TyVarRecord _ fs1), - TyVarSum _ cs - ) -> - typeError (locOf reason) mempty $ - "Cannot unify record type" - indent 2 (pretty (Record fs1)) - "with type" - indent 2 (pretty (Scalar (Sum cs))) - -- - -- Internal error cases - (TyVarSol {}, _) -> - alreadySolved - (TyVarParam {}, _) -> - isParam - where - unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v - alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v - alreadySolved = error $ "Type variable already solved: " <> prettyNameString v - isParam = error $ "Type name is a type parameter: " <> prettyNameString v -unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a -unsharedConstructorsMsg cs1 cs2 = - "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." +solveCt :: CtTy () -> SolveM s () +solveCt (CtEq reason t1 t2) = solveEq reason mempty t1 t2 + +solveEq :: Reason Type -> BreadCrumbs -> Type -> Type -> SolveM s () +solveEq reason obcs orig_t1 orig_t2 = do + solveCt' (obcs, (orig_t1, orig_t2)) where - missing = - filter (`notElem` M.keys cs1) (M.keys cs2) - ++ filter (`notElem` M.keys cs2) (M.keys cs1) + flexible :: VName -> SolveM s (Maybe (TyVarNode s)) + flexible v = do + uf <- asks solverTyVars + case M.lookup v uf of + j_n@(Just node) -> do + sol <- getSol' node + pure $ case sol of + Unsolved _ -> j_n + _ -> Nothing + Nothing -> pure Nothing --- Unify at the root, emitting new equalities that must hold. + normalize :: TypeBase () NoUniqueness -> SolveM s (TypeBase () NoUniqueness) + normalize t@(Scalar (TypeVar _ (QualName [] v) [])) = do + uf <- asks solverTyVars + case M.lookup v uf of + Just node -> do + sol <- getSol' node + case sol of + Solved t' -> normalize t' + _ -> typeVar <$> getKey' node + Nothing -> pure t + normalize t = pure t + + solveCt' :: (BreadCrumbs, (Type, Type)) -> SolveM s () + solveCt' (bcs, (t1, t2)) = do + t1' <- normalize t1 + t2' <- normalize t2 + case (t1', t2') of + ( Scalar (TypeVar _ (QualName [] v1) []), + Scalar (TypeVar _ (QualName [] v2) []) + ) + | v1 == v2 -> pure () + | otherwise -> do + mb_node1 <- flexible v1 + mb_node2 <- flexible v2 + case (mb_node1, mb_node2) of + (Nothing, Nothing) -> + cannotUnify reason mempty bcs t1 t2 + (Just v1_node, Nothing) -> + bindTyVar reason bcs v1 v1_node t2' + (Nothing, Just v2_node) -> + bindTyVar reason bcs v2 v2_node t1' + (Just v1_node, Just v2_node) -> + unionTyVars reason bcs v1 v1_node v2_node + (Scalar (TypeVar _ (QualName [] v1) []), _) -> do + mb_node <- flexible v1 + case mb_node of + Just node -> bindTyVar reason bcs v1 node t2' + Nothing -> tryUnify t1' t2' reason bcs + (_, Scalar (TypeVar _ (QualName [] v2) [])) -> do + mb_node <- flexible v2 + case mb_node of + Just node -> bindTyVar reason bcs v2 node t1' + Nothing -> tryUnify t1' t2' reason bcs + (_, _) -> tryUnify t1' t2' reason bcs + + tryUnify :: Type -> Type -> Reason Type -> BreadCrumbs -> SolveM s () + tryUnify t1 t2 r bcs = + case unify t1 t2 of + Left details -> cannotUnify r (aNote details) bcs t1 t2 + Right eqs -> mapM_ solveCt' eqs + +-- | Unify at the root, emitting new equalities that must hold. unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) | pt1 == pt2 = Right [] @@ -487,11 +377,13 @@ unify where f (TypeArgType t1, TypeArgType t2) = Just (mempty, (t1, t2)) f _ = Nothing -unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = - Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] - where - t1r' = t1r `setUniqueness` NoUniqueness - t2r' = t2r `setUniqueness` NoUniqueness +unify + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = + Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness unify (Scalar (Record fs1)) (Scalar (Record fs2)) | M.keys fs1 == M.keys fs2 = Right $ @@ -529,78 +421,253 @@ unify t1 t2 Right [(mempty, (t1', t2'))] unify _ _ = Left mempty -solveEq :: Reason Type -> BreadCrumbs -> Type -> Type -> SolveM () -solveEq reason obcs orig_t1 orig_t2 = do - solveCt' (obcs, (orig_t1, orig_t2)) +maybeLookupTyVarSol :: TyVar -> SolveM s (Maybe TyVarSol) +maybeLookupTyVarSol tv = do + tyvars <- asks solverTyVars + case M.lookup tv tyvars of + Nothing -> pure Nothing + Just node -> do + sol <- getSol' node + pure $ Just sol + +lookupTyVar :: TyVar -> SolveM s (Either (TyVarInfo ()) Type) +lookupTyVar tv = + maybe bad unpack <$> maybeLookupTyVarSol tv where - solveCt' (bcs, (t1, t2)) = do - tyvars <- gets solverTyVars - let flexible v = case M.lookup v tyvars of - Just (Left v') -> flexible v' - Just (Right (TyVarUnsol _)) -> True - Just (Right TyVarSol {}) -> False - Just (Right TyVarParam {}) -> False - Nothing -> False - sub t@(Scalar (TypeVar u (QualName [] v) [])) = - case M.lookup v tyvars of - Just (Left v') -> sub $ Scalar (TypeVar u (QualName [] v') []) - Just (Right (TyVarSol t')) -> sub t' - _ -> t - sub t = t - case (sub t1, sub t2) of - ( t1'@(Scalar (TypeVar _ (QualName [] v1) [])), - t2'@(Scalar (TypeVar _ (QualName [] v2) [])) - ) - | v1 == v2 -> pure () - | otherwise -> - case (flexible v1, flexible v2) of - (False, False) -> cannotUnify reason mempty bcs t1 t2 - (True, False) -> subTyVar reason bcs v1 t2' - (False, True) -> subTyVar reason bcs v2 t1' - (True, True) -> unionTyVars reason bcs v1 v2 - (Scalar (TypeVar _ (QualName [] v1) []), t2') - | flexible v1 -> subTyVar reason bcs v1 t2' - (t1', Scalar (TypeVar _ (QualName [] v2) [])) - | flexible v2 -> subTyVar reason bcs v2 t1' - (t1', t2') -> case unify t1' t2' of - Left details -> cannotUnify reason (aNote details) bcs t1' t2' - Right eqs -> mapM_ solveCt' eqs - -solveCt :: CtTy () -> SolveM () -solveCt ct = - case ct of - CtEq reason t1 t2 -> solveEq reason mempty t1 t2 - -scopeCheck :: Reason Type -> TyVar -> Int -> Type -> SolveM () -scopeCheck reason v v_lvl ty = do - mapM_ check $ typeVars ty + bad = error $ "Unknown tyvar: " <> prettyNameString tv + unpack (Param {}) = error $ "Is a type param: " <> prettyNameString tv + unpack (Solved t) = Right t + unpack (Unsolved info) = Left info + +lookupTyVarInfo :: TyVarNode s -> SolveM s (TyVarInfo ()) +lookupTyVarInfo v_node = do + r <- getSol' v_node + case r of + Unsolved info -> pure info + _ -> do + v <- getKey' v_node + error $ "Tyvar is nonflexible: " <> prettyNameString v + +lookupUF :: TyVar -> SolveM s (TyVarNode s) +lookupUF tv = do + uf <- asks solverTyVars + case M.lookup tv uf of + Nothing -> error $ "Unknown tyvar: " <> prettyNameString tv + Just node -> pure node + +unifySharedFields :: + Reason Type -> + BreadCrumbs -> + M.Map Name Type -> + M.Map Name Type -> + SolveM s () +unifySharedFields reason bcs fs1 fs2 = + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> + solveEq reason (matchingField f <> bcs) ts1 ts2 + +unifySharedConstructors :: + Reason Type -> + BreadCrumbs -> + M.Map Name [Type] -> + M.Map Name [Type] -> + SolveM s () +unifySharedConstructors reason bcs cs1 cs2 = + forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> + if length ts1 == length ts2 + then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 + else + typeError (locOf reason) mempty $ + "Cannot unify type with constructor" + indent 2 (pretty (Sum (M.singleton c ts1))) + "with type of constructor" + indent 2 (pretty (Sum (M.singleton c ts2))) + "because they differ in arity." + +setInfo :: TyVarNode s -> TyVarSol -> SolveM s () +setInfo node sol = liftST $ assignNewSol node sol + +unionTyVars :: + Reason Type -> + BreadCrumbs -> + VName -> + TyVarNode s -> + TyVarNode s -> + SolveM s () +unionTyVars reason bcs v v_node t_node = do + v_sol <- getSol' v_node + t_info <- lookupTyVarInfo t_node + c <- check v_sol t_info + case c of + Left (loc, notes, msg) -> typeError loc notes msg + Right (Just new_sol) -> unionNewSol' v_node t_node new_sol + Right Nothing -> union' v_node t_node where + check :: + TyVarSol -> + TyVarInfo () -> + SolveM s (Either (Loc, Notes, Doc ()) (Maybe TyVarSol)) + check v_sol t_info = + case (v_sol, t_info) of + (Unsolved (TyVarFree _ v_l), TyVarFree t_loc t_l) + | v_l /= t_l -> + pure $ Right $ Just $ Unsolved $ TyVarFree t_loc (min v_l t_l) + (Unsolved info, TyVarFree {}) -> do + pure $ Right $ Just $ Unsolved info + -- + -- TyVarPrim cases + ( Unsolved (TyVarPrim _ v_pts), + TyVarPrim t_loc t_pts + ) -> + let pts = L.intersect v_pts t_pts + in if null pts + then + pure $ + Left + ( locOf reason, + mempty, + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be one of" + indent 2 (pretty t_pts) + ) + else pure $ Right $ Just $ Unsolved $ TyVarPrim t_loc pts + (Unsolved (TyVarPrim _ v_pts), TyVarRecord {}) -> + pure $ + Left + ( locOf reason, + mempty, + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be a record." + ) + (Unsolved (TyVarPrim _ v_pts), TyVarSum {}) -> + pure $ + Left + ( locOf reason, + mempty, + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be sum." + ) + -- + -- TyVarSum cases + ( Unsolved (TyVarSum _ cs1), + TyVarSum loc cs2 + ) -> do + unifySharedConstructors reason bcs cs1 cs2 + let cs3 = cs1 <> cs2 + pure $ Right $ Just $ Unsolved $ TyVarSum loc cs3 + ( Unsolved TyVarSum {}, + TyVarPrim _ pts + ) -> + pure $ + Left + ( locOf reason, + mempty, + "A sum type cannot be one of" + indent 2 (pretty pts) + ) + ( Unsolved (TyVarSum _ cs1), + TyVarRecord _ fs + ) -> + pure $ + Left + ( locOf reason, + mempty, + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Scalar (Record fs))) + ) + -- + -- TyVarRecord cases + ( Unsolved (TyVarRecord _ fs1), + TyVarRecord loc fs2 + ) -> do + unifySharedFields reason bcs fs1 fs2 + let fs3 = fs1 <> fs2 + pure $ Right $ Just $ Unsolved $ TyVarRecord loc fs3 + ( Unsolved TyVarRecord {}, + TyVarPrim _ pts + ) -> + pure $ + Left + ( locOf reason, + mempty, + "A record type cannot be one of" + indent 2 (pretty pts) + ) + ( Unsolved (TyVarRecord _ fs1), + TyVarSum _ cs + ) -> + pure $ + Left + ( locOf reason, + mempty, + "Cannot unify record type" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty (Scalar (Sum cs))) + ) + -- + -- Internal error cases + (Solved {}, _) -> alreadySolved + (Param {}, _) -> isParam + _ -> pure $ Right Nothing + + alreadySolved = error $ "Type variable already solved: " <> prettyNameString v + isParam = error $ "Type name is a type parameter: " <> prettyNameString v + +scopeViolation :: Reason Type -> VName -> Type -> VName -> SolveM s () +scopeViolation reason v1 ty v2 = + typeError (locOf reason) mempty $ + "Cannot unify type" + indent 2 (pretty ty) + "with" + <+> dquotes (prettyName v1) + <+> "(scope violation)." + "This is because" + <+> dquotes (prettyName v2) + <+> "is rigidly bound in a deeper scope." + +scopeCheck :: Reason Type -> TyVar -> Level -> Type -> SolveM s () +scopeCheck reason v v_lvl ty = mapM_ check $ typeVars ty + where + check :: TyVar -> SolveM s () check ty_v = do - ty_v_info <- gets $ M.lookup ty_v . solverTyVars - case ty_v_info of - Just (Right (TyVarParam ty_v_lvl _ _)) - | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v - Just (Right (TyVarSol ty')) -> + maybe (pure ()) checkNode =<< maybeLookupUF ty_v + + checkNode :: TyVarNode s -> SolveM s () + checkNode node = do + sol <- getSol' node + case sol of + Param ty_v_lvl _ _ + | ty_v_lvl > v_lvl -> do + k <- getKey' node + ty' <- substTyVars ty + scopeViolation reason v ty' k + Solved ty' -> do mapM_ check $ typeVars ty' _ -> pure () --- If a type variable has a liftedness constraint, we propagate that +-- | If a type variable has a liftedness constraint, we propagate that -- constraint to its solution. The actual checking for correct usage -- is done later. -liftednessCheck :: Liftedness -> Type -> SolveM () +liftednessCheck :: Liftedness -> Type -> SolveM s () liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do - v_info <- maybeLookupTyVar v + v_info <- maybeLookupTyVarSol v case v_info of Nothing -> -- Is an opaque type. pure () - Just (TyVarSol v_ty) -> + Just (Solved v_ty) -> liftednessCheck l v_ty - Just TyVarParam {} -> pure () - Just (TyVarUnsol (TyVarFree loc v_l)) - | l /= v_l -> - setInfo v $ TyVarUnsol $ TyVarFree loc (min l v_l) - Just TyVarUnsol {} -> pure () + Just Param {} -> pure () + Just (Unsolved (TyVarFree loc v_l)) + | l < v_l -> do + node <- lookupUF v + setInfo node $ Unsolved $ TyVarFree loc l + Just Unsolved {} -> pure () liftednessCheck _ (Scalar Prim {}) = pure () liftednessCheck Lifted _ = pure () liftednessCheck _ Array {} = pure () @@ -611,7 +678,25 @@ liftednessCheck l (Scalar (Sum cs)) = mapM_ (mapM_ $ liftednessCheck l) cs liftednessCheck _ (Scalar TypeVar {}) = pure () -solveTyVar :: (VName, (Level, TyVarInfo ())) -> SolveM () +solveTyVar :: (VName, (Level, TyVarInfo ())) -> SolveM s () +solveTyVar (tv, (lvl, TyVarFree loc l)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right ty -> do + scopeCheck (Reason loc) tv lvl ty + liftednessCheck l ty + _ -> pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right ty + | ty `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + typeError loc mempty $ + "Numeric constant inferred to be of type" + indent 2 (align (pretty ty)) + "which is not possible." + _ -> pure () solveTyVar (tv, (_, TyVarRecord loc fs1)) = do tv_t <- lookupTyVar tv case tv_t of @@ -622,8 +707,7 @@ solveTyVar (tv, (_, TyVarRecord loc fs1)) = do <+> "is ambiguous." "Must be a record with fields" indent 2 (pretty (Scalar (Record fs1))) - Right _ -> - pure () + Right _ -> pure () solveTyVar (tv, (_, TyVarSum loc cs1)) = do tv_t <- lookupTyVar tv case tv_t of @@ -633,29 +717,60 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) Right _ -> pure () -solveTyVar (tv, (lvl, TyVarFree loc l)) = do - tv_t <- lookupTyVar tv - case tv_t of - Right ty -> do - scopeCheck (Reason loc) tv lvl ty - liftednessCheck l ty - _ -> pure () -solveTyVar (tv, (_, TyVarPrim loc pts)) = do - tv_t <- lookupTyVar tv - case tv_t of - Right (Scalar (Prim ty)) - | [ty] == pts -> - setInfo tv $ TyVarSol $ Scalar $ Prim ty - Right ty - | ty `elem` map (Scalar . Prim) pts -> pure () - | otherwise -> - typeError loc mempty $ - "Numeric constant inferred to be of type" - indent 2 (align (pretty ty)) - "which is not possible." - _ -> pure () --- Print in a way helpful for writing a test case for TySolveTests. +maybeLookupUF :: TyVar -> SolveM s (Maybe (TyVarNode s)) +maybeLookupUF tv = do + uf <- asks solverTyVars + pure . M.lookup tv $ uf + +getSolution :: SolveM s ([UnconTyVar], Solution) +getSolution = do + uf <- asks solverTyVars + resolved <- M.traverseWithKey resolve uf + let unconstrained = M.foldrWithKey unconstr [] resolved + sol = M.mapMaybeWithKey mkSubst resolved + pure (unconstrained, sol) + where + resolve :: + TyVar -> + TyVarNode s -> + SolveM s (Either [PrimType] (TypeBase () NoUniqueness), Maybe Liftedness) + resolve tv node = do + sol <- getSol' node + case sol of + Unsolved (TyVarFree _ l) -> do + k <- getKey' node + let tv' = typeVar k + -- If the current type variable and root type variable are + -- different, this variable is unconstrained, so we save the + -- liftedness constraint for later. + pure (Right tv', if k == tv then Just l else Nothing) + Unsolved (TyVarPrim _ pts) -> pure (Left pts, Nothing) + Solved t -> do + t' <- substTyVars t + pure (Right $ first (const ()) t', Nothing) + _ -> do + k <- getKey' node + pure (Right $ typeVar k, Nothing) + + unconstr :: + TyVar -> + (Either [PrimType] (TypeBase () NoUniqueness), Maybe Liftedness) -> + [UnconTyVar] -> + [UnconTyVar] + unconstr tv (_, Just l) acc = (tv, l) : acc + unconstr _ _ acc = acc + + mkSubst :: + TyVar -> + (Either [PrimType] (TypeBase () NoUniqueness), Maybe Liftedness) -> + Maybe (Either [PrimType] (TypeBase () NoUniqueness)) + mkSubst _ (_, Just _) = Nothing + mkSubst tv (s@(Right (Scalar (TypeVar _ (QualName [] tv') _))), _) = + if tv /= tv' then Just s else Nothing + mkSubst _ (s, _) = Just s + +-- | Print in a way helpful for writing a test case for TySolveTests. logSolution :: [CtTy ()] -> TyParams -> @@ -698,14 +813,12 @@ solve :: TyVars () -> Either TypeError ([UnconTyVar], Solution) solve constraints typarams tyvars = - maybeLog - . second solution - . runExcept - . flip execStateT (initialState typarams tyvars) - . runSolveM - $ do + maybeLog $ runST $ do + r <- initializeState typarams tyvars + flip runReaderT r $ runExceptT $ runSolveM $ do mapM_ solveCt constraints - mapM_ solveTyVar (M.toList tyvars) + mapM_ solveTyVar $ M.toList tyvars + getSolution where maybeLog | isEnvVarAtLeast "FUTHARK_LOG_TYSOLVE" 0 = \s -> diff --git a/src/Language/Futhark/TypeChecker/TySolveOld.hs b/src/Language/Futhark/TypeChecker/TySolveOld.hs new file mode 100644 index 0000000000..de5311e151 --- /dev/null +++ b/src/Language/Futhark/TypeChecker/TySolveOld.hs @@ -0,0 +1,703 @@ +-- | The constraint solver for unsized type equality constraints. +module Language.Futhark.TypeChecker.TySolveOld + ( Type, + Solution, + UnconTyVar, + solve, + ) +where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.State +import Data.Bifunctor +import Data.List qualified as L +import Data.Loc +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Debug.Trace +import Futhark.Util (isEnvVarAtLeast) +import Futhark.Util.Pretty +import Language.Futhark +import Language.Futhark.TypeChecker.Constraints +import Language.Futhark.TypeChecker.Error +import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), aNote, prettyTypeError) +import Language.Futhark.TypeChecker.Types (substTyVars) + +-- | The type representation used by the constraint solver. Agnostic +-- to sizes and uniqueness. +type Type = CtType () + +-- | A (partial) solution for a type variable. +data TyVarSol + = -- | Has been substituted with this. + TyVarSol Type + | -- | Is an explicit (rigid) type parameter in the source program. + TyVarParam Level Liftedness Loc + | -- | Not substituted yet; has this constraint. + TyVarUnsol (TyVarInfo ()) + deriving (Show) + +newtype SolverState = SolverState + { -- | Left means linked to this other type variable. + solverTyVars :: M.Map TyVar (Either VName TyVarSol) + } + +initialState :: TyParams -> TyVars () -> SolverState +initialState typarams tyvars = SolverState $ M.map g typarams <> M.map f tyvars + where + f (_lvl, info) = Right $ TyVarUnsol info + g (lvl, l, loc) = Right $ TyVarParam lvl l loc + +substTyVar :: (Monoid u) => M.Map TyVar (Either VName TyVarSol) -> VName -> Maybe (TypeBase () u) +substTyVar m v = + case M.lookup v m of + Just (Left v') -> substTyVar m v' + Just (Right (TyVarSol t')) -> Just $ second (const mempty) $ substTyVars (substTyVar m) t' + Just (Right TyVarParam {}) -> Nothing + Just (Right (TyVarUnsol {})) -> Nothing + Nothing -> Nothing + +maybeLookupTyVar :: TyVar -> SolveM (Maybe TyVarSol) +maybeLookupTyVar orig = do + tyvars <- gets solverTyVars + let f v = case M.lookup v tyvars of + Nothing -> pure Nothing + Just (Left v') -> f v' + Just (Right info) -> pure $ Just info + f orig + +lookupTyVar :: TyVar -> SolveM (Either (TyVarInfo ()) Type) +lookupTyVar orig = + maybe bad unpack <$> maybeLookupTyVar orig + where + bad = error $ "Unknown tyvar: " <> prettyNameString orig + unpack (TyVarParam {}) = error $ "Is a type param: " <> prettyNameString orig + unpack (TyVarSol t) = Right t + unpack (TyVarUnsol info) = Left info + +-- | Variable must be flexible. +lookupTyVarInfo :: TyVar -> SolveM (TyVarInfo ()) +lookupTyVarInfo v = do + r <- lookupTyVar v + case r of + Left info -> pure info + Right _ -> error $ "Tyvar is nonflexible: " <> prettyNameString v + +setLink :: TyVar -> VName -> SolveM () +setLink v info = modify $ \s -> s {solverTyVars = M.insert v (Left info) $ solverTyVars s} + +setInfo :: TyVar -> TyVarSol -> SolveM () +setInfo v info = modify $ \s -> s {solverTyVars = M.insert v (Right info) $ solverTyVars s} + +-- | A solution maps a type variable to its substitution. This +-- substitution is complete, in the sense there are no right-hand +-- sides that contain a type variable. +type Solution = M.Map TyVar (Either [PrimType] (TypeBase () NoUniqueness)) + +-- | An unconstrained type variable comprises a name and (ironically) +-- a constraint on how it can be instantiated. +type UnconTyVar = (VName, Liftedness) + +typeVar :: (Monoid u) => VName -> TypeBase dim u +typeVar v = Scalar $ TypeVar mempty (qualName v) [] + +solution :: SolverState -> ([UnconTyVar], Solution) +solution s = + ( mapMaybe unconstrained $ M.toList $ solverTyVars s, + M.mapMaybe mkSubst $ solverTyVars s + ) + where + mkSubst (Right (TyVarSol t)) = + Just $ Right $ first (const ()) $ substTyVars (substTyVar (solverTyVars s)) t + mkSubst (Left v') = + Just . fromMaybe (Right $ Scalar $ TypeVar mempty (qualName v') []) $ + mkSubst =<< M.lookup v' (solverTyVars s) + mkSubst (Right (TyVarUnsol (TyVarPrim _ pts))) = Just $ Left pts + mkSubst _ = Nothing + + unconstrained (v, Right (TyVarUnsol (TyVarFree _ l))) = Just (v, l) + unconstrained _ = Nothing + +newtype SolveM a = SolveM {runSolveM :: StateT SolverState (Except TypeError) a} + deriving (Functor, Applicative, Monad, MonadState SolverState, MonadError TypeError) + +-- Try to substitute as much information as we have. +enrichType :: Type -> SolveM Type +enrichType t = do + s <- get + pure $ substTyVars (substTyVar (solverTyVars s)) t + +typeError :: Loc -> Notes -> Doc () -> SolveM () +typeError loc notes msg = + throwError $ TypeError loc notes msg + +occursCheck :: Reason Type -> VName -> Type -> SolveM () +occursCheck reason v tp = do + vars <- gets solverTyVars + let tp' = substTyVars (substTyVar vars) tp + when (v `S.member` typeVars tp') . typeError (locOf reason) mempty $ + "Occurs check: cannot instantiate" + <+> prettyName v + <+> "with" + <+> pretty tp + <> "." + +unifySharedConstructors :: + Reason Type -> + BreadCrumbs -> + M.Map Name [Type] -> + M.Map Name [Type] -> + SolveM () +unifySharedConstructors reason bcs cs1 cs2 = + forM_ (M.toList $ M.intersectionWith (,) cs1 cs2) $ \(c, (ts1, ts2)) -> + if length ts1 == length ts2 + then zipWithM_ (solveEq reason $ matchingConstructor c <> bcs) ts1 ts2 + else + typeError (locOf reason) mempty $ + "Cannot unify type with constructor" + indent 2 (pretty (Sum (M.singleton c ts1))) + "with type of constructor" + indent 2 (pretty (Sum (M.singleton c ts2))) + "because they differ in arity." + +unifySharedFields :: + Reason Type -> + BreadCrumbs -> + M.Map Name Type -> + M.Map Name Type -> + SolveM () +unifySharedFields reason bcs fs1 fs2 = + forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> + solveEq reason (matchingField f <> bcs) ts1 ts2 + +scopeViolation :: Reason Type -> VName -> Type -> VName -> SolveM () +scopeViolation reason v1 ty v2 = + typeError (locOf reason) mempty $ + "Cannot unify type" + indent 2 (pretty ty) + "with" + <+> dquotes (prettyName v1) + <+> "(scope violation)." + "This is because" + <+> dquotes (prettyName v2) + <+> "is rigidly bound in a deeper scope." + +cannotUnify :: + Reason Type -> + Notes -> + BreadCrumbs -> + Type -> + Type -> + SolveM () +cannotUnify reason notes bcs t1 t2 = do + t1' <- enrichType t1 + t2' <- enrichType t2 + case reason of + Reason loc -> + typeError loc notes . stack $ + [ "Cannot unify", + indent 2 (pretty t1'), + "with", + indent 2 (pretty t2') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonPatMatch loc pat value_t -> + typeError loc notes . stack $ + [ "Pattern", + indent 2 $ align $ pretty pat, + "cannot match value of type", + indent 2 $ align $ pretty value_t + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonAscription loc expected actual -> + typeError loc notes . stack $ + [ "Expression does not have expected type from type ascription.", + "Expected:" <+> align (pretty expected), + "Actual: " <+> align (pretty actual) + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonRetType loc expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ "Function body does not have expected type.", + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + <> [pretty bcs | not $ hasNoBreadCrumbs bcs] + ReasonApply loc f e expected actual -> do + expected' <- enrichType expected + actual' <- enrichType actual + typeError loc notes . stack $ + [ header, + "Expected:" <+> align (pretty expected'), + "Actual: " <+> align (pretty actual') + ] + where + header = + case f of + (Nothing, _) -> + "Cannot apply function to" + <+> dquotes (shorten $ group $ pretty e) + <> " (invalid type)." + (Just fname, _) -> + "Cannot apply" + <+> dquotes (pretty fname) + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> " (invalid type)." + ReasonApplySplit loc (fname, 0) _ ftype -> + typeError loc notes $ + stack + [ "Cannot apply" + <+> fname' + <+> "as function, as it has non-function type:" + indent 2 (align $ pretty ftype) + ] + where + fname' = maybe "expression" (dquotes . pretty) fname + ReasonApplySplit loc (fname, i) e _ -> + typeError loc notes $ + stack + [ "Cannot apply" + <+> fname' + <+> "to" + <+> dquotes (align $ shorten $ group $ pretty e) + <> ".", + "Function accepts only" <+> pretty i <+> "arguments." + ] + where + fname' = maybe "expression" (dquotes . pretty) fname + ReasonBranches loc former latter -> do + former' <- enrichType former + latter' <- enrichType latter + typeError loc notes . stack $ + [ "Branches differ in type.", + "Former:" <+> pretty former', + "Latter:" <+> pretty latter' + ] + +-- Precondition: 'v' is currently flexible. +subTyVar :: Reason Type -> BreadCrumbs -> VName -> Type -> SolveM () +subTyVar reason bcs v t = do + occursCheck reason v t + v_info <- gets $ M.lookup v . solverTyVars + + -- Set a solution for v, then update info for t in case v has any + -- odd constraints. + setInfo v (TyVarSol t) + + case (v_info, t) of + (Just (Right (TyVarUnsol TyVarFree {})), _) -> pure () + (Just (Right (TyVarUnsol (TyVarPrim _ v_pts))), _) -> + if t `elem` map (Scalar . Prim) v_pts + then pure () + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot instantiate type that must be one of" + indent 2 (pretty v_pts) + "with" + indent 2 (pretty t) + (Just (Right (TyVarUnsol (TyVarSum _ cs1))), Scalar (Sum cs2)) -> + if all (`elem` M.keys cs2) (M.keys cs1) + then unifySharedConstructors reason bcs cs1 cs2 + else cannotUnify reason notes bcs (typeVar v) t + where + notes = + aNote $ + "Cannot match type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs1))) + "with type with constructors" + indent 2 (stack (map (("#" <>) . pretty) (M.keys cs2))) + unsharedConstructorsMsg cs1 cs2 + (Just (Right (TyVarUnsol (TyVarSum _ cs1))), _) -> + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty t) + (Just (Right (TyVarUnsol (TyVarRecord _ fs1))), Scalar (Record fs2)) -> + if all (`elem` M.keys fs2) (M.keys fs1) + then unifySharedFields reason bcs fs1 fs2 + else + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with record type" + indent 2 (pretty (Record fs2)) + (Just (Right (TyVarUnsol (TyVarRecord _ fs1))), _) -> + typeError (locOf reason) mempty $ + "Cannot unify record type with fields" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty t) + -- + -- Internal error cases + (Just (Right TyVarSol {}), _) -> + error $ "Type variable already solved: " <> prettyNameString v + (Just (Right TyVarParam {}), _) -> + error $ "Cannot substitute type parameter: " <> prettyNameString v + (Just Left {}, _) -> + error $ "Type variable already linked: " <> prettyNameString v + (Nothing, _) -> + error $ "subTyVar: Nothing v: " <> prettyNameString v + +-- Precondition: 'v' and 't' are both currently flexible. +-- +-- The purpose of this function is to combine the partial knowledge we +-- may have about these two type variables. +unionTyVars :: Reason Type -> BreadCrumbs -> VName -> VName -> SolveM () +unionTyVars reason bcs v t = do + v_info <- gets $ either alreadyLinked id . fromMaybe unknown . M.lookup v . solverTyVars + t_info <- lookupTyVarInfo t + + -- Insert the link from v to t, and then update the info of t based + -- on the existing info of v and t. + setLink v t + + case (v_info, t_info) of + ( TyVarUnsol (TyVarFree _ v_l), + TyVarFree t_loc t_l + ) + | v_l /= t_l -> + setInfo t $ TyVarUnsol $ TyVarFree t_loc (min v_l t_l) + -- When either is completely unconstrained. + (TyVarUnsol TyVarFree {}, _) -> + pure () + ( TyVarUnsol info, + TyVarFree {} + ) -> + setInfo t (TyVarUnsol info) + -- + -- TyVarPrim cases + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarPrim t_loc t_pts + ) -> + let pts = L.intersect v_pts t_pts + in if null pts + then + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be one of" + indent 2 (pretty t_pts) + else setInfo t (TyVarUnsol (TyVarPrim t_loc pts)) + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarRecord {} + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be a record." + ( TyVarUnsol (TyVarPrim _ v_pts), + TyVarSum {} + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type that must be one of" + indent 2 (pretty v_pts) + "with type that must be sum." + -- + -- TyVarSum cases + ( TyVarUnsol (TyVarSum _ cs1), + TyVarSum loc cs2 + ) -> do + unifySharedConstructors reason bcs cs1 cs2 + let cs3 = cs1 <> cs2 + setInfo t (TyVarUnsol (TyVarSum loc cs3)) + ( TyVarUnsol TyVarSum {}, + TyVarPrim _ pts + ) -> + typeError (locOf reason) mempty $ + "A sum type cannot be one of" + indent 2 (pretty pts) + ( TyVarUnsol (TyVarSum _ cs1), + TyVarRecord _ fs + ) -> + typeError (locOf reason) mempty $ + "Cannot unify type with constructors" + indent 2 (pretty (Sum cs1)) + "with type" + indent 2 (pretty (Scalar (Record fs))) + -- + -- TyVarRecord cases + ( TyVarUnsol (TyVarRecord _ fs1), + TyVarRecord loc fs2 + ) -> do + unifySharedFields reason bcs fs1 fs2 + let fs3 = fs1 <> fs2 + setInfo t (TyVarUnsol (TyVarRecord loc fs3)) + ( TyVarUnsol TyVarRecord {}, + TyVarPrim _ pts + ) -> + typeError (locOf reason) mempty $ + "A record type cannot be one of" + indent 2 (pretty pts) + ( TyVarUnsol (TyVarRecord _ fs1), + TyVarSum _ cs + ) -> + typeError (locOf reason) mempty $ + "Cannot unify record type" + indent 2 (pretty (Record fs1)) + "with type" + indent 2 (pretty (Scalar (Sum cs))) + -- + -- Internal error cases + (TyVarSol {}, _) -> + alreadySolved + (TyVarParam {}, _) -> + isParam + where + unknown = error $ "unionTyVars: Nothing v: " <> prettyNameString v + alreadyLinked = error $ "Type variable already linked: " <> prettyNameString v + alreadySolved = error $ "Type variable already solved: " <> prettyNameString v + isParam = error $ "Type name is a type parameter: " <> prettyNameString v + +unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a +unsharedConstructorsMsg cs1 cs2 = + "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "." + where + missing = + filter (`notElem` M.keys cs1) (M.keys cs2) + ++ filter (`notElem` M.keys cs2) (M.keys cs1) + +-- Unify at the root, emitting new equalities that must hold. +unify :: Type -> Type -> Either (Doc a) [(BreadCrumbs, (Type, Type))] +unify (Scalar (Prim pt1)) (Scalar (Prim pt2)) + | pt1 == pt2 = Right [] +unify + (Scalar (TypeVar _ (QualName _ v1) targs1)) + (Scalar (TypeVar _ (QualName _ v2) targs2)) + | v1 == v2 = + Right $ mapMaybe f $ zip targs1 targs2 + where + f (TypeArgType t1, TypeArgType t2) = Just (mempty, (t1, t2)) + f _ = Nothing +unify (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) = + Right [(mempty, (t1a, t2a)), (mempty, (t1r', t2r'))] + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness +unify (Scalar (Record fs1)) (Scalar (Record fs2)) + | M.keys fs1 == M.keys fs2 = + Right $ + map (first matchingField) $ + M.toList $ + M.intersectionWith (,) fs1 fs2 + | Just n1 <- length <$> areTupleFields fs1, + Just n2 <- length <$> areTupleFields fs2, + n1 /= n2 = + Left $ + "Tuples have" + <+> pretty n1 + <+> "and" + <+> pretty n2 + <+> "elements respectively." + | otherwise = + let missing = + filter (`notElem` M.keys fs1) (M.keys fs2) + <> filter (`notElem` M.keys fs2) (M.keys fs1) + in Left $ + "unshared fields:" <+> commasep (map pretty missing) <> "." +unify (Scalar (Sum cs1)) (Scalar (Sum cs2)) + | M.keys cs1 == M.keys cs2 = + fmap concat . forM cs' $ \(c, (ts1, ts2)) -> do + if length ts1 == length ts2 + then Right $ zipWith (curry (matchingConstructor c,)) ts1 ts2 + else Left mempty + | otherwise = + Left $ unsharedConstructorsMsg cs1 cs2 + where + cs' = M.toList $ M.intersectionWith (,) cs1 cs2 +unify t1 t2 + | Just t1' <- peelArray 1 t1, + Just t2' <- peelArray 1 t2 = + Right [(mempty, (t1', t2'))] +unify _ _ = Left mempty + +solveEq :: Reason Type -> BreadCrumbs -> Type -> Type -> SolveM () +solveEq reason obcs orig_t1 orig_t2 = do + solveCt' (obcs, (orig_t1, orig_t2)) + where + solveCt' (bcs, (t1, t2)) = do + tyvars <- gets solverTyVars + let flexible v = case M.lookup v tyvars of + Just (Left v') -> flexible v' + Just (Right (TyVarUnsol _)) -> True + Just (Right TyVarSol {}) -> False + Just (Right TyVarParam {}) -> False + Nothing -> False + normalize t@(Scalar (TypeVar u (QualName [] v) [])) = + case M.lookup v tyvars of + Just (Left v') -> normalize $ Scalar (TypeVar u (QualName [] v') []) + Just (Right (TyVarSol t')) -> normalize t' + _ -> t + normalize t = t + case (normalize t1, normalize t2) of + ( t1'@(Scalar (TypeVar _ (QualName [] v1) [])), + t2'@(Scalar (TypeVar _ (QualName [] v2) [])) + ) + | v1 == v2 -> pure () + | otherwise -> + case (flexible v1, flexible v2) of + (False, False) -> cannotUnify reason mempty bcs t1 t2 + (True, False) -> subTyVar reason bcs v1 t2' + (False, True) -> subTyVar reason bcs v2 t1' + (True, True) -> unionTyVars reason bcs v1 v2 + (Scalar (TypeVar _ (QualName [] v1) []), t2') + | flexible v1 -> subTyVar reason bcs v1 t2' + (t1', Scalar (TypeVar _ (QualName [] v2) [])) + | flexible v2 -> subTyVar reason bcs v2 t1' + (t1', t2') -> case unify t1' t2' of + Left details -> cannotUnify reason (aNote details) bcs t1' t2' + Right eqs -> mapM_ solveCt' eqs + +solveCt :: CtTy () -> SolveM () +solveCt ct = + case ct of + CtEq reason t1 t2 -> solveEq reason mempty t1 t2 + +scopeCheck :: Reason Type -> TyVar -> Int -> Type -> SolveM () +scopeCheck reason v v_lvl ty = mapM_ check $ typeVars ty + where + check ty_v = do + ty_v_info <- gets $ M.lookup ty_v . solverTyVars + case ty_v_info of + Just (Right (TyVarParam ty_v_lvl _ _)) + -- Type parameter has a higher level than the (free) type variable. + | ty_v_lvl > v_lvl -> scopeViolation reason v ty ty_v + Just (Right (TyVarSol ty')) -> + mapM_ check $ typeVars ty' + _ -> pure () + +-- If a type variable has a liftedness constraint, we propagate that +-- constraint to its solution. The actual checking for correct usage +-- is done later. +liftednessCheck :: Liftedness -> Type -> SolveM () +liftednessCheck l (Scalar (TypeVar _ (QualName [] v) _)) = do + v_info <- maybeLookupTyVar v + case v_info of + Nothing -> + -- Is an opaque type. + pure () + Just (TyVarSol v_ty) -> + liftednessCheck l v_ty + Just TyVarParam {} -> pure () + Just (TyVarUnsol (TyVarFree loc v_l)) + | l /= v_l -> + setInfo v $ TyVarUnsol $ TyVarFree loc (min l v_l) + Just TyVarUnsol {} -> pure () +liftednessCheck _ (Scalar Prim {}) = pure () +liftednessCheck Lifted _ = pure () +liftednessCheck _ Array {} = pure () +liftednessCheck _ (Scalar Arrow {}) = pure () +liftednessCheck l (Scalar (Record fs)) = + mapM_ (liftednessCheck l) fs +liftednessCheck l (Scalar (Sum cs)) = + mapM_ (mapM_ $ liftednessCheck l) cs +liftednessCheck _ (Scalar TypeVar {}) = pure () + +solveTyVar :: (VName, (Level, TyVarInfo ())) -> SolveM () +solveTyVar (tv, (_, TyVarRecord loc fs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Left _ -> + typeError loc mempty $ + "Type" + <+> prettyName tv + <+> "is ambiguous." + "Must be a record with fields" + indent 2 (pretty (Scalar (Record fs1))) + Right _ -> + pure () +solveTyVar (tv, (_, TyVarSum loc cs1)) = do + tv_t <- lookupTyVar tv + case tv_t of + Left _ -> + typeError loc mempty $ + "Type is ambiguous." + "Must be a sum type with constructors" + indent 2 (pretty (Scalar (Sum cs1))) + Right _ -> pure () +solveTyVar (tv, (lvl, TyVarFree loc l)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right ty -> do + scopeCheck (Reason loc) tv lvl ty + liftednessCheck l ty + _ -> pure () +solveTyVar (tv, (_, TyVarPrim loc pts)) = do + tv_t <- lookupTyVar tv + case tv_t of + Right (Scalar (Prim ty)) + | [ty] == pts -> + setInfo tv $ TyVarSol $ Scalar $ Prim ty + Right ty + | ty `elem` map (Scalar . Prim) pts -> pure () + | otherwise -> + typeError loc mempty $ + "Numeric constant inferred to be of type" + indent 2 (align (pretty ty)) + "which is not possible." + _ -> pure () + +-- Print in a way helpful for writing a test case for TySolveTests. +logSolution :: + [CtTy ()] -> + TyParams -> + TyVars () -> + Either TypeError ([UnconTyVar], Solution) -> + String +logSolution constraints typarams tyvars s = + unlines $ + ["# TySolve.solve", "## constraints"] + <> map ppConstraint constraints + <> [ "## typarams", + if typarams == mempty then "mempty" else show $ map ppTyParam (M.toList typarams) + ] + <> [ "## tyvars", + show $ map (bimap prettyNameString (second onTyVar)) $ M.toList tyvars, + either + (("## error\n" <>) . docString . prettyTypeError) + ( ("## solution\n" <>) + . show + . bimap + (map (first prettyNameString)) + (map (bimap prettyNameString $ bimap prettyString prettyString) . M.toList) + ) + s + ] + where + ppConstraint (CtEq _ t1 t2) = + unwords [show (prettyString t1), "~", show (prettyString t2)] + ppTyParam (p, (lvl, info, _)) = show (prettyNameString p, (lvl, info, NoLoc)) + onTyVar (TyVarFree _ l) = TyVarFree NoLoc l + onTyVar (TyVarPrim _ pts) = TyVarPrim NoLoc pts + onTyVar (TyVarRecord _ ts) = TyVarRecord NoLoc ts + onTyVar (TyVarSum _ ts) = TyVarSum NoLoc ts + +-- | Solve type constraints, producing either an error or a solution, +-- alongside a list of unconstrained type variables. +solve :: + [CtTy ()] -> + TyParams -> + TyVars () -> + Either TypeError ([UnconTyVar], Solution) +solve constraints typarams tyvars = + maybeLog + . second solution + . runExcept + . flip execStateT (initialState typarams tyvars) + . runSolveM + $ do + mapM_ solveCt constraints + mapM_ solveTyVar (M.toList tyvars) + where + maybeLog + | isEnvVarAtLeast "FUTHARK_LOG_TYSOLVE" 0 = \s -> + trace (logSolution constraints typarams tyvars s) s + | otherwise = id +{-# NOINLINE solve #-} diff --git a/src/Language/Futhark/TypeChecker/UnionFind.hs b/src/Language/Futhark/TypeChecker/UnionFind.hs new file mode 100644 index 0000000000..e718c6db2b --- /dev/null +++ b/src/Language/Futhark/TypeChecker/UnionFind.hs @@ -0,0 +1,137 @@ +module Language.Futhark.TypeChecker.UnionFind + ( TyVarNode, + TyVarSol (..), + makeTyVarNode, + makeTyParamNode, + find, + getSol, + getKey, + assignNewSol, + union, + unionNewSol, + ) +where + +import Control.Monad (when) +import Control.Monad.ST (ST) +import Data.STRef + ( STRef, + modifySTRef', + newSTRef, + readSTRef, + writeSTRef, + ) +import Language.Futhark (Liftedness, Loc) +import Language.Futhark.TypeChecker.Constraints + ( CtType, + Level, + TyVar, + TyVarInfo, + ) + +type Type = CtType () + +-- | A (partial) solution for a type variable. +data TyVarSol + = -- | Has been assigned this type. + Solved Type + | -- | Is an explicit (rigid) type parameter in the source program. + Param Level Liftedness Loc + | -- | Is unsolved but has this constraint. + Unsolved (TyVarInfo ()) + deriving (Show, Eq) + +-- | A node in the union-find graph containing information about a type +-- variable. +newtype TyVarNode s = Node (STRef s (NodeInfo s)) deriving (Eq) + +data NodeInfo s + = Link !(TyVarNode s) + | Repr !ReprInfo + +data ReprInfo = ReprInfo + { solution :: !TyVarSol, + key :: !TyVar + } + +-- | Create a fresh node of a type variable and return it. A fresh node +-- is in the equivalence class that contains only itself. +makeTyVarNode :: TyVar -> TyVarInfo () -> ST s (TyVarNode s) +makeTyVarNode tv constraint = do + let r = + ReprInfo + { solution = Unsolved constraint, + key = tv + } + ref <- newSTRef $ Repr r + pure $ Node ref + +-- | Create a fresh node of a type parameter and return it. A fresh node +-- is in the equivalence class that contains only itself. +makeTyParamNode :: TyVar -> Level -> Liftedness -> Loc -> ST s (TyVarNode s) +makeTyParamNode tv lvl lft loc = do + let r = + ReprInfo + { solution = Param lvl lft loc, + key = tv + } + ref <- newSTRef $ Repr r + pure $ Node ref + +-- | @find node@ returns the representative of @node@'s +-- equivalence class and the information associated with +-- this equivalence class. +-- +-- This method performs the path compresssion. +find :: TyVarNode s -> ST s (TyVarNode s, ReprInfo) +find node@(Node ref) = do + node_info <- readSTRef ref + case node_info of + -- Input node is representative. + Repr repr_info -> pure (node, repr_info) + -- Input node's parent is another node. + Link parent -> do + a@(repr, _) <- find parent + when (repr /= parent) $ + -- Performing path compression. + writeSTRef ref $ + Link repr + pure a + +-- | Return the solution associated with the argument node's +-- equivalence class. +getSol :: TyVarNode s -> ST s TyVarSol +getSol node = solution . snd <$> find node + +-- | Return the name of the representative type variable. +getKey :: TyVarNode s -> ST s TyVar +getKey node = key . snd <$> find node + +-- | Assign a new solution/type to the node's equivalence class. +-- +-- Precondition: The node is in an equivalence class representing an +-- unsolved/flexible type variable. +assignNewSol :: TyVarNode s -> TyVarSol -> ST s () +assignNewSol node new_sol = do + (Node ref, repr_info) <- find node + modifySTRef' ref $ const . Repr $ repr_info {solution = new_sol} + +-- | Join the equivalence classes of the nodes. The resulting equivalence +-- class has the same solution and key as the second argument. +union :: TyVarNode s -> TyVarNode s -> ST s () +union n1 n2 = do + Node ref <- fst <$> find n1 + root2 <- fst <$> find n2 + + writeSTRef ref $ Link root2 + +-- | Join the equivalence classes of the nodes. The resulting equivalence +-- class has the same key as the second argument while @new_sol@ is the +-- new solution. +unionNewSol :: TyVarNode s -> TyVarNode s -> TyVarSol -> ST s () +unionNewSol n1 n2 new_sol = do + Node ref1 <- fst <$> find n1 + (root2@(Node ref2), repr_info) <- find n2 + + modifySTRef' ref2 $ const . Repr $ repr_info {solution = new_sol} + writeSTRef ref1 $ Link root2 From 5ff3c67815e486b134528aeec6fceaa0887a0956 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 30 Jul 2025 14:49:55 +0200 Subject: [PATCH 292/296] Constraint-based type checking without AUTOMAP. I'm not sure when we will be able to finish AUTOMAP, but the new type checker should be doable. AUTOMAP can then be built on top of this. --- default.nix | 4 - futhark.cabal | 11 - nix/glpk-hs.nix | 23 - prelude/soacs.fut | 2 +- prelude/zip.fut | 18 +- shell.nix | 1 - .../Futhark/Solve/BranchAndBoundTests.hs | 143 ----- src-testing/Futhark/Solve/SimplexTests.hs | 221 -------- src-testing/futhark_tests.hs | 4 - src/Futhark/Solve/BranchAndBound.hs | 74 --- src/Futhark/Solve/GLPK.hs | 60 --- src/Futhark/Solve/LP.hs | 336 ------------ src/Futhark/Solve/Matrix.hs | 330 ------------ src/Futhark/Solve/Simplex.hs | 235 --------- src/Language/Futhark/Interpreter.hs | 12 + src/Language/Futhark/Prop.hs | 10 + .../Futhark/TypeChecker/Constraints.hs | 10 + src/Language/Futhark/TypeChecker/Rank.hs | 493 +----------------- src/Language/Futhark/TypeChecker/Terms.hs | 8 + .../Futhark/TypeChecker/Terms/Unsized.hs | 56 +- src/Language/Futhark/TypeChecker/TySolve.hs | 6 +- tests/automap/ambiguous0.fut | 4 - tests/automap/bool1.fut | 6 - tests/automap/combinations.fut | 38 -- tests/automap/equality1.fut | 23 - tests/automap/lambda.fut | 6 - tests/automap/leetcode.fut | 4 - tests/automap/map0.fut | 8 - tests/automap/mri-q-qr.fut | 2 - tests/automap/mri-q.fut | 41 -- tests/automap/operator1.fut | 9 - tests/automap/optionpricing.fut | 78 --- tests/automap/pagerank.fut | 18 - tests/automap/project.fut | 9 - tests/automap/projsec1.fut | 9 - tests/automap/same_typevar.fut | 16 - tests/automap/sgemm.fut | 32 -- tests/automap/simple1.fut | 7 - tests/automap/simple2.fut | 8 - tests/automap/simple3.fut | 8 - tests/automap/simple4.fut | 8 - tests/automap/simple5.fut | 6 - tests/issue1599.fut | 4 + tests/issue1926.fut | 5 +- tests/types/inference5.fut | 7 + 45 files changed, 103 insertions(+), 2310 deletions(-) delete mode 100644 nix/glpk-hs.nix delete mode 100644 src-testing/Futhark/Solve/BranchAndBoundTests.hs delete mode 100644 src-testing/Futhark/Solve/SimplexTests.hs delete mode 100644 src/Futhark/Solve/BranchAndBound.hs delete mode 100644 src/Futhark/Solve/GLPK.hs delete mode 100644 src/Futhark/Solve/LP.hs delete mode 100644 src/Futhark/Solve/Matrix.hs delete mode 100644 src/Futhark/Solve/Simplex.hs delete mode 100644 tests/automap/ambiguous0.fut delete mode 100644 tests/automap/bool1.fut delete mode 100644 tests/automap/combinations.fut delete mode 100644 tests/automap/equality1.fut delete mode 100644 tests/automap/lambda.fut delete mode 100644 tests/automap/leetcode.fut delete mode 100644 tests/automap/map0.fut delete mode 100644 tests/automap/mri-q-qr.fut delete mode 100644 tests/automap/mri-q.fut delete mode 100644 tests/automap/operator1.fut delete mode 100644 tests/automap/optionpricing.fut delete mode 100644 tests/automap/pagerank.fut delete mode 100644 tests/automap/project.fut delete mode 100644 tests/automap/projsec1.fut delete mode 100644 tests/automap/same_typevar.fut delete mode 100644 tests/automap/sgemm.fut delete mode 100644 tests/automap/simple1.fut delete mode 100644 tests/automap/simple2.fut delete mode 100644 tests/automap/simple3.fut delete mode 100644 tests/automap/simple4.fut delete mode 100644 tests/automap/simple5.fut create mode 100644 tests/issue1599.fut create mode 100644 tests/types/inference5.fut diff --git a/default.nix b/default.nix index 5821f6ca21..c43240b135 100644 --- a/default.nix +++ b/default.nix @@ -37,9 +37,6 @@ let gasp = haskellPackagesNew.callPackage ./nix/gasp.nix {}; - glpk-hs = - haskellPackagesNew.callPackage ./nix/glpk-hs.nix {}; - futhark = # callCabal2Nix does not do a great job at determining # which files must be included as source, which causes @@ -78,7 +75,6 @@ let "--extra-lib-dirs=${pkgs.glibc.static}/lib" "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" - "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { dontDisableStatic = true; })}/lib" # The ones below are due to GHC's runtime system # depending on libdw (DWARF info), which depends on # a bunch of compression algorithms. diff --git a/futhark.cabal b/futhark.cabal index d72eb9a204..2ea03448c9 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -382,11 +382,6 @@ library Futhark.Pkg.Types Futhark.Profile Futhark.Script - Futhark.Solve.GLPK - Futhark.Solve.LP - Futhark.Solve.Matrix - Futhark.Solve.Simplex - Futhark.Solve.BranchAndBound Futhark.Test Futhark.Test.Spec Futhark.Test.Values @@ -511,9 +506,6 @@ library , mwc-random , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 - -- remove me later - , glpk-hs - , silently executable futhark import: common @@ -549,8 +541,6 @@ library futhark-testing Futhark.Optimise.ArrayLayoutTests Futhark.Pkg.SolveTests Futhark.ProfileTests - Futhark.Solve.BranchAndBoundTests - Futhark.Solve.SimplexTests Language.Futhark.CoreTests Language.Futhark.PrettyTests Language.Futhark.ParserBenchmarks @@ -575,7 +565,6 @@ library futhark-testing , tasty-hunit , tasty-quickcheck , text - , vector >=0.12 , srcloc , regex-tdfa ^>= 1.3.2 diff --git a/nix/glpk-hs.nix b/nix/glpk-hs.nix deleted file mode 100644 index 189135ed22..0000000000 --- a/nix/glpk-hs.nix +++ /dev/null @@ -1,23 +0,0 @@ -{ mkDerivation, array, base, containers, deepseq, fetchgit, gasp -, glpk, lib, mtl -}: -mkDerivation { - pname = "glpk-hs"; - version = "0.8"; - src = fetchgit { - url = "https://github.com/jyp/glpk-hs.git"; - sha256 = "sha256-AY9wmmqzafpocUspQAvHjDkT4vty5J3GcSOt5qItnlo="; - rev = "1f276aa19861203ea8367dc27a6ad4c8a31c9062"; - fetchSubmodules = true; - }; - isLibrary = true; - isExecutable = true; - libraryHaskellDepends = [ array base containers deepseq gasp mtl ]; - librarySystemDepends = [ glpk ]; - executableHaskellDepends = [ - array base containers deepseq gasp mtl - ]; - description = "Comprehensive GLPK linear programming bindings"; - license = lib.licenses.bsd3; - mainProgram = "glpk-hs-example"; -} diff --git a/prelude/soacs.fut b/prelude/soacs.fut index 71ee9ed5bf..ea3cc90614 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -48,7 +48,7 @@ import "zip" -- -- **Span:** *O(S(f))* def map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = - f as + intrinsics.map f as -- | Apply the given function to each element of a single array. -- diff --git a/prelude/zip.fut b/prelude/zip.fut index 5ccbacc17b..cf57c71f09 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -6,6 +6,12 @@ -- The main reason this module exists is that we need it to define -- SOACs like `map2`. +-- We need a map to define some of the zip variants, but this file is +-- depended upon by soacs.fut. So we just define a quick-and-dirty +-- internal one here that uses the intrinsic version. +local +def internal_map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = + intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = @@ -17,15 +23,15 @@ def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = -- | As `zip2`@term, but with one more array. def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c) : *[n](a, b, c) = - (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) + internal_map (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) -- | As `zip3`@term, but with one more array. def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) : *[n](a, b, c, d) = - (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) + internal_map (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) -- | As `zip4`@term, but with one more array. def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e) : *[n](a, b, c, d, e) = - (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) + internal_map (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) -- | Turn an array of pairs into two arrays. def unzip [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = @@ -37,18 +43,18 @@ def unzip2 [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = -- | As `unzip2`@term, but with one more array. def unzip3 [n] 'a 'b 'c (xs: [n](a, b, c)) : ([n]a, [n]b, [n]c) = - let (as, bcs) = unzip ((\(a, b, c) -> (a, (b, c))) xs) + let (as, bcs) = unzip (internal_map (\(a, b, c) -> (a, (b, c))) xs) let (bs, cs) = unzip bcs in (as, bs, cs) -- | As `unzip3`@term, but with one more array. def unzip4 [n] 'a 'b 'c 'd (xs: [n](a, b, c, d)) : ([n]a, [n]b, [n]c, [n]d) = - let (as, bs, cds) = unzip3 ((\(a, b, c, d) -> (a, b, (c, d))) xs) + let (as, bs, cds) = unzip3 (internal_map (\(a, b, c, d) -> (a, b, (c, d))) xs) let (cs, ds) = unzip cds in (as, bs, cs, ds) -- | As `unzip4`@term, but with one more array. def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a, b, c, d, e)) : ([n]a, [n]b, [n]c, [n]d, [n]e) = - let (as, bs, cs, des) = unzip4 ((\(a, b, c, d, e) -> (a, b, c, (d, e))) xs) + let (as, bs, cs, des) = unzip4 (internal_map (\(a, b, c, d, e) -> (a, b, c, (d, e))) xs) let (ds, es) = unzip des in (as, bs, cs, ds, es) diff --git a/shell.nix b/shell.nix index 72e1cac32e..8f563ed276 100644 --- a/shell.nix +++ b/shell.nix @@ -43,7 +43,6 @@ pkgs.stdenv.mkDerivation { niv ispc imagemagick # needed for literate tests - glpk ] ++ lib.optionals (stdenv.isLinux) [ opencl-headers diff --git a/src-testing/Futhark/Solve/BranchAndBoundTests.hs b/src-testing/Futhark/Solve/BranchAndBoundTests.hs deleted file mode 100644 index b7e1bfe027..0000000000 --- a/src-testing/Futhark/Solve/BranchAndBoundTests.hs +++ /dev/null @@ -1,143 +0,0 @@ -{-# OPTIONS_GHC -fno-warn-type-defaults #-} - -module Futhark.Solve.BranchAndBoundTests - ( tests, - ) -where - -import Data.Vector.Unboxed qualified as V -import Futhark.Solve.BranchAndBound -import Futhark.Solve.LP -import Futhark.Solve.Matrix qualified as M -import Test.Tasty -import Test.Tasty.HUnit -import Prelude hiding (or) - -tests :: TestTree -tests = - testGroup - "BranchAndBoundTests" - [ -- testCase "1" $ - -- let lpe = - -- LPE - -- { pc = V.fromList [1, 1, 0, 0, 0], - -- pA = - -- M.fromLists - -- [ [-1, 1, 1, 0, 0], - -- [1, 0, 0, 1, 0], - -- [0, 1, 0, 0, 1] - -- ], - -- pd = V.fromList [1, 3, 2] - -- } - -- in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), - testCase "2" $ - let lp = - LP - { lpc = V.fromList [40, 30], - lpA = - M.fromLists - [ [1, 1], - [2, 1] - ], - lpd = V.fromList [12, 16] - } - in branchAndBound lp @?= Just (400 :: Double, V.fromList [4, 8]), - testCase "3" $ - let lp = - LP - { lpc = V.fromList [1, 2, 3], - lpA = - M.fromLists - [ [1, 1, 1], - [2, 1, 3] - ], - lpd = V.fromList [12, 18] - } - in branchAndBound lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), - testCase "4" $ - let lp = - LP - { lpc = V.fromList [5.5, 2.1], - lpA = - M.fromLists - [ [-1, 1], - [8, 2] - ], - lpd = V.fromList [2, 17] - } - in assertBool (show $ branchAndBound lp) $ - case branchAndBound lp of - Nothing -> False - Just (z, sol) -> - (z `approxEq` (11.8 :: Double)) - && and (zipWith (==) (V.toList sol) [1, 3]), - -- testCase "5" $ - -- let prog = - -- LinearProg - -- { optType = Maximize, - -- objective = var "x1" ~+~ var "x2", - -- constraints = - -- [ var "x1" ~<=~ constant 10, - -- var "x2" ~<=~ constant 5 - -- ] - -- <> oneIsZero ("b1", "x1") ("b2", "x2") - -- } - -- (lp, _idxmap) = linearProgToLP prog - -- in assertBool - -- (unlines [show $ branchAndBound lp]) - -- $ case branchAndBound lp of - -- Nothing -> False - -- Just (z, _sol) -> - -- and - -- [ z `approxEq` (10 :: Double) - -- ], - -- testCase "6" $ - -- let prog = - -- LinearProg - -- { optType = Maximize, - -- objective = var "x1" ~+~ var "x2", - -- constraints = - -- [ var "x1" ~<=~ constant 10, - -- var "x2" ~<=~ constant 5 - -- ] - -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) - -- } - -- (lp, idxmap) = linearProgToLP prog - -- lpe = convert lp - -- in assertBool - -- (unlines [show $ branchAndBound lp]) - -- $ case branchAndBound lp of - -- Nothing -> False - -- Just (z, sol) -> - -- and - -- [ z `approxEq` (10 :: Double) - -- ] - - testCase "10" $ - let prog = - LinearProg - { optType = Minimize, - objective = var "R2" ~+~ var "M3", - constraints = - [ var "artifical4" ~==~ constant 1 ~+~ var "t0", - constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", - var "b_R2" ~<=~ constant 1, - var "b_M3" ~<=~ constant 1, - var "R2" ~<=~ 1000 ~*~ var "b_R2", - var "M3" ~<=~ 1000 ~*~ var "b_M3", - var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 - ] - } - (lp, _idxmap) = linearProgToLP prog - in assertBool - (unlines [show $ branchAndBound lp]) - $ case branchAndBound lp of - Nothing -> False - Just (z, _sol) -> - and - [ z `approxEq` (0 :: Double) - ] - ] - -approxEq :: (Fractional a, Ord a) => a -> a -> Bool -approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/src-testing/Futhark/Solve/SimplexTests.hs b/src-testing/Futhark/Solve/SimplexTests.hs deleted file mode 100644 index c29bd10a93..0000000000 --- a/src-testing/Futhark/Solve/SimplexTests.hs +++ /dev/null @@ -1,221 +0,0 @@ -{-# OPTIONS_GHC -fno-warn-type-defaults #-} - -module Futhark.Solve.SimplexTests - ( tests, - ) -where - -import Data.Vector.Unboxed qualified as V -import Futhark.Solve.LP -import Futhark.Solve.Matrix qualified as M -import Futhark.Solve.Simplex -import Test.Tasty -import Test.Tasty.HUnit -import Prelude hiding (or) - -tests :: TestTree -tests = - testGroup - "SimplexTests" - [ testCase "1" $ - let lpe = - LPE - { pc = V.fromList [1, 1, 0, 0, 0], - pA = - M.fromLists - [ [-1, 1, 1, 0, 0], - [1, 0, 0, 1, 0], - [0, 1, 0, 0, 1] - ], - pd = V.fromList [1, 3, 2] - } - in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), - testCase "2" $ - let lp = - LP - { lpc = V.fromList [40, 30], - lpA = - M.fromLists - [ [1, 1], - [2, 1] - ], - lpd = V.fromList [12, 16] - } - in simplexLP lp @?= Just (400 :: Double, V.fromList [4, 8]), - testCase "3" $ - let lp = - LP - { lpc = V.fromList [1, 2, 3], - lpA = - M.fromLists - [ [1, 1, 1], - [2, 1, 3] - ], - lpd = V.fromList [12, 18] - } - in simplexLP lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), - testCase "4" $ - let lp = - LP - { lpc = V.fromList [5.5, 2.1], - lpA = - M.fromLists - [ [-1, 1], - [8, 2] - ], - lpd = V.fromList [2, 17] - } - in assertBool (show $ simplexLP lp) $ - case simplexLP lp of - Nothing -> False - Just (z, sol) -> - (z `approxEq` (14.08 :: Double)) - && and (zipWith approxEq (V.toList sol) [1.3, 3.3]), - testCase "5" $ - let lp = - LP - { lpc = V.fromList [0], - lpA = - M.fromLists - [ [1], - [-1] - ], - lpd = V.fromList [0, 0] - } - in assertBool (show $ simplexLP lp) $ - case simplexLP lp of - Nothing -> False - Just (z, sol) -> - (z `approxEq` (0 :: Double)) - && and (zipWith approxEq (V.toList sol) [0]), - testCase "6" $ - let lp = - LP - { lpc = V.fromList [1], - lpA = - M.fromLists - [ [1], - [-1] - ], - lpd = V.fromList [5, 5] - } - in assertBool (show $ simplexLP lp) $ - case simplexLP lp of - Nothing -> False - Just (z, sol) -> - z `approxEq` (5 :: Double) - && and (zipWith approxEq (V.toList sol) [5]), - testCase "7" $ - let prog = - LinearProg - { optType = Maximize, - objective = var "x1", - constraints = - [ var "x1" ~<=~ 10 ~*~ var "b1", - var "b1" ~+~ var "b2" ~<=~ constant 1 - ] - } - (lp, _idxmap) = linearProgToLP prog - in assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, sol) -> - (z `approxEq` (10 :: Double)) - && and (zipWith (==) (V.toList sol) [1, 0, 10]), - testCase "8" $ - let prog = - LinearProg - { optType = Maximize, - objective = var "x1" ~+~ var "x2", - constraints = - [ var "x1" ~<=~ constant 10, - var "x2" ~<=~ constant 5 - ] - <> oneIsZero ("b1", "x1") ("b2", "x2") - } - (lp, _idxmap) = linearProgToLP prog - in assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, _sol) -> - and - [ z `approxEq` (15 :: Double) - ], - -- testCase "9" $ - -- let prog = - -- LinearProg - -- { optType = Maximize, - -- objective = var "x1" ~+~ var "x2", - -- constraints = - -- [ var "x1" ~<=~ constant 10, - -- var "x2" ~<=~ constant 5 - -- ] - -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) - -- } - -- (lp, idxmap) = linearProgToLP prog - -- lpe = convert lp - -- in trace - -- (unlines [show prog, show lp, show idxmap, show lpe]) - -- ( assertBool - -- (unlines [show $ simplexLP lp]) - -- $ case simplexLP lp of - -- Nothing -> False - -- Just (z, sol) -> - -- and - -- [ z `approxEq` (15 :: Double) - -- ] - -- ), - testCase "10" $ - let prog = - LinearProg - { optType = Minimize, - objective = var "R2" ~+~ var "M3", - constraints = - [ var "artifical4" ~==~ constant 1 ~+~ var "t0", - constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", - var "b_R2" ~<=~ constant 1, - var "b_M3" ~<=~ constant 1, - var "R2" ~<=~ 1000 ~*~ var "b_R2", - var "M3" ~<=~ 1000 ~*~ var "b_M3", - var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 - ] - } - (lp, _idxmap) = linearProgToLP prog - in assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, _sol) -> - and - [ z `approxEq` (0 :: Double) - ], - testCase "11" $ - let prog = - LinearProg - { optType = Minimize, - objective = var "4R" ~+~ var "5M", - constraints = - [ var "6artifical" ~==~ constant 1 ~+~ var "2t", - constant 1 ~+~ var "3num" ~==~ constant 1 ~+~ var "2t", - var "0b_R" ~<=~ constant 1, - var "1b_M" ~<=~ constant 1, - var "4R" ~<=~ 1000 ~*~ var "0b_R", - var "5M" ~<=~ 1000 ~*~ var "1b_M", - var "0b_R" ~+~ var "1b_M" ~<=~ constant 1 - ] - } - (lp, _idxmap) = linearProgToLP prog - in assertBool - (unlines [show $ simplexLP lp]) - $ case simplexLP lp of - Nothing -> False - Just (z, _sol) -> - and - [ z `approxEq` (0 :: Double) - ] - ] - -approxEq :: (Fractional a, Ord a) => a -> a -> Bool -approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/src-testing/futhark_tests.hs b/src-testing/futhark_tests.hs index d39059dc03..18c85cf6d1 100644 --- a/src-testing/futhark_tests.hs +++ b/src-testing/futhark_tests.hs @@ -11,8 +11,6 @@ import Futhark.Internalise.TypesValuesTests qualified import Futhark.Optimise.ArrayLayoutTests qualified import Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests qualified import Futhark.Pkg.SolveTests qualified -import Futhark.Solve.BranchAndBoundTests qualified -import Futhark.Solve.SimplexTests qualified import Language.Futhark.PrettyTests qualified import Language.Futhark.PrimitiveTests qualified import Language.Futhark.SemanticTests qualified @@ -39,8 +37,6 @@ allTests = Futhark.Analysis.AlgSimplifyTests.tests, Language.Futhark.TypeCheckerTests.tests, Language.Futhark.SemanticTests.tests, - Futhark.Solve.SimplexTests.tests, - Futhark.Solve.BranchAndBoundTests.tests, Futhark.Optimise.ArrayLayoutTests.tests ] diff --git a/src/Futhark/Solve/BranchAndBound.hs b/src/Futhark/Solve/BranchAndBound.hs deleted file mode 100644 index 258757113b..0000000000 --- a/src/Futhark/Solve/BranchAndBound.hs +++ /dev/null @@ -1,74 +0,0 @@ -module Futhark.Solve.BranchAndBound (branchAndBound) where - -import Data.Map qualified as M -import Data.Maybe -import Data.Set qualified as S -import Data.Vector.Unboxed (Unbox, Vector) -import Data.Vector.Unboxed qualified as V -import Futhark.Solve.LP (LP (..)) -import Futhark.Solve.Matrix -import Futhark.Solve.Simplex - -newtype Bound a = Bound (Maybe a, Maybe a) - deriving (Eq, Ord, Show) - -instance (Ord a) => Semigroup (Bound a) where - Bound (mlb1, mub1) <> Bound (mlb2, mub2) = - Bound (combine max mlb1 mlb2, combine min mub1 mub2) - where - combine _ Nothing b2 = b2 - combine _ b1 Nothing = b1 - combine c (Just b1) (Just b2) = Just $ c b1 b2 - --- | Solves an LP with the additional constraint that all solutions --- must be integral. Returns 'Nothing' if infeasible or unbounded. -branchAndBound :: - (Read a, Unbox a, RealFrac a, Show a) => - LP a -> - Maybe (a, Vector Int) -branchAndBound prob@(LP _ a d) = (zopt,) <$> mopt - where - (zopt, mopt) = step (S.singleton mempty) (negate $ read "Infinity") Nothing - step todo zlow opt - | S.null todo = (zlow, opt) - | otherwise = - let (next, rest) = S.deleteFindMin todo - in case simplexLP (mkProblem next) of - Nothing -> step rest zlow opt - Just (z, sol) - | z <= zlow -> step rest zlow opt - | V.all isInt sol -> - step rest z (Just $ V.map round sol) - | otherwise -> - let (idx, frac) = - V.head $ V.filter (not . isInt . snd) $ V.zip (V.generate (V.length sol) id) sol - new_todo = - S.fromList $ - filter - (/= next) - [ M.insertWith (<>) idx (Bound (Nothing, Just $ fromInteger $ floor frac)) next, - M.insertWith (<>) idx (Bound (Just $ fromInteger $ ceiling frac, Nothing)) next - ] - in step (new_todo <> rest) zlow opt - - -- TODO: use isInt x = x == round x - -- requires a better 'rowEchelon' implementation for matrices - isInt x = abs (fromIntegral (round x :: Int) - x) <= 10 ^^ ((-10) :: Int) - mkProblem = - M.foldrWithKey - ( \idx bound acc -> addBound acc idx bound - ) - prob - - addBound lp idx (Bound (mlb, mub)) = - lp - { lpA = a `addRows` new_rows, - lpd = d V.++ V.fromList new_ds - } - where - (new_rows, new_ds) = - unzip $ - catMaybes - [ (V.generate (ncols a) (\i -> if i == idx then (-1) else 0),) <$> (negate <$> mlb), - (V.generate (ncols a) (\i -> if i == idx then 1 else 0),) <$> mub - ] diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs deleted file mode 100644 index 5c8f40fcd8..0000000000 --- a/src/Futhark/Solve/GLPK.hs +++ /dev/null @@ -1,60 +0,0 @@ -module Futhark.Solve.GLPK (glpk) where - -import Control.Monad -import Data.Bifunctor -import Data.LinearProgram -import Data.Map qualified as M -import Data.Maybe -import Data.Set qualified as S -import Futhark.Solve.LP qualified as F -import System.IO.Silently - -linearProgToGLPK :: (Ord v, Num a) => F.LinearProg v a -> LP v a -linearProgToGLPK prog = - LP - { direction = cOptType $ F.optType prog, - objective = cObj $ F.objective prog, - constraints = map cConstraint $ F.constraints prog, - varBounds = bounds, - varTypes = kinds - } - where - cOptType F.Maximize = Max - cOptType F.Minimize = Min - cObj = fst . cLSum - - cLSum (F.LSum m) = - ( M.mapKeys fromJust $ M.filterWithKey (\k _ -> isJust k) m, - fromMaybe 0 (m M.!? Nothing) - ) - - cConstraint (F.Constraint ctype l r) = - let (linfunc, c) = cLSum $ l F.~-~ r - bound = - case ctype of - F.Equal -> Equ (-c) - F.LessEq -> UBound (-c) - in Constr Nothing linfunc bound - - bounds = M.fromList $ (,LBound 0) <$> varList - kinds = M.fromList $ (,IntVar) <$> varList - - varList = S.toList $ F.vars prog - -glpk :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) -glpk lp = do - (output, res) <- capture $ glpk' lp - pure $ do - guard $ "PROBLEM HAS NO INTEGER FEASIBLE SOLUTION" `notElem` lines output - res - -glpk' :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) -glpk' lp - | F.isConstant (F.objective lp) -- FIXME - = - pure $ pure (0, M.fromList $ map (,0) $ S.toList $ F.vars lp) - | otherwise = do - (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp - pure $ bimap truncate (fmap truncate) <$> mres - where - opts = mipDefaults {msgLev = MsgAll} diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs deleted file mode 100644 index 5011ece9fb..0000000000 --- a/src/Futhark/Solve/LP.hs +++ /dev/null @@ -1,336 +0,0 @@ -module Futhark.Solve.LP - ( LP (..), - LPE (..), - convert, - normalize, - var, - constant, - cval, - bin, - or, - min, - max, - oneIsZero, - (~+~), - (~-~), - (~*~), - (!), - neg, - linearProgToLP, - linearProgToLPE, - LSum (..), - LinearProg (..), - OptType (..), - Constraint (..), - Vars (..), - CType (..), - (~==~), - (~<=~), - (~>=~), - rowEchelonLPE, - isConstant, - ) -where - -import Data.Map (Map) -import Data.Map qualified as M -import Data.Maybe -import Data.Set (Set) -import Data.Set qualified as S -import Data.Vector.Unboxed (Unbox, Vector) -import Data.Vector.Unboxed qualified as V -import Futhark.Solve.Matrix (Matrix (..)) -import Futhark.Solve.Matrix qualified as Matrix -import Futhark.Util.Pretty -import Language.Futhark.Pretty -import Prelude hiding (max, min, or) - --- | A linear program. 'LP c a d' represents the program --- --- > maximize c^T * a --- > subject to a * x <= d --- > x >= 0 --- --- The matrix 'a' is assumed to have linearly-independent rows. -data LP a = LP - { lpc :: Vector a, - lpA :: Matrix a, - lpd :: Vector a - } - deriving (Eq, Show) - --- | Equational form of a linear program. 'LPE c a d' represents the --- program --- --- > maximize c^T * a --- > subject to a * x = d --- > x >= 0 --- --- The matrix 'a' is assumed to have linearly-independent rows. -data LPE a = LPE - { pc :: Vector a, - pA :: Matrix a, - pd :: Vector a - } - deriving (Eq, Show) - -rowEchelonLPE :: (Unbox a, Fractional a, Ord a) => LPE a -> LPE a -rowEchelonLPE (LPE c a d) = - LPE c (Matrix.sliceCols (V.generate (ncols a) id) ad) (Matrix.getCol (ncols a) ad) - where - ad = - Matrix.filterRows - (V.any (Prelude./= 0)) - (Matrix.rowEchelon $ a Matrix.<|> Matrix.fromColVector d) - --- | Converts an 'LP' into an equivalent 'LPE' by introducing slack --- variables. -convert :: (Num a, Unbox a) => LP a -> LPE a -convert (LP c a d) = LPE c' a' d - where - a' = a Matrix.<|> Matrix.diagonal (V.replicate (Matrix.nrows a) 1) - c' = c V.++ V.replicate (Matrix.nrows a) 0 - --- | Linear sum of variables. -newtype LSum v a = LSum {lsum :: Map (Maybe v) a} - deriving (Show, Eq) - -instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where - pretty (LSum m) = - concatWith (surround " + ") - $ map - ( \(k, a) -> - case k of - Nothing -> pretty a - Just k' -> (if a == 1 then mempty else pretty a <> "*") <> prettyName k' - ) - $ M.toList m - -isConstant :: (Ord v) => LSum v a -> Bool -isConstant (LSum m) = M.keysSet m `S.isSubsetOf` S.singleton Nothing - -instance Functor (LSum v) where - fmap f (LSum m) = LSum $ fmap f m - -class Vars a v where - vars :: a -> Set v - -instance (Ord v) => Vars (LSum v a) v where - vars = S.fromList . catMaybes . M.keys . lsum - --- | Type of constraint -data CType = Equal | LessEq - deriving (Show, Eq) - -instance Pretty CType where - pretty Equal = "==" - pretty LessEq = "<=" - --- | A constraint for a linear program. -data Constraint v a - = Constraint CType (LSum v a) (LSum v a) - deriving (Show, Eq) - -instance (IsName v, Pretty a, Eq a, Num a) => Pretty (Constraint v a) where - pretty (Constraint t l r) = - pretty l <+> pretty t <+> pretty r - -instance (Ord v) => Vars (Constraint v a) v where - vars (Constraint _ l r) = vars l <> vars r - -data OptType = Maximize | Minimize - deriving (Show, Eq) - -instance Pretty OptType where - pretty Maximize = "maximize" - pretty Minimize = "minimize" - --- | A linear program. -data LinearProg v a = LinearProg - { optType :: OptType, - objective :: LSum v a, - constraints :: [Constraint v a] - } - deriving (Show, Eq) - -instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where - pretty (LinearProg opt obj cs) = - vcat - [ pretty opt, - indent 2 $ pretty obj, - "subject to", - indent 2 $ vcat $ map pretty cs - ] - -instance (Ord v) => Vars (LinearProg v a) v where - vars lp = - vars (objective lp) - <> foldMap vars (constraints lp) - -bigM :: (Num a) => a -bigM = 2 ^ (10 :: Int) - --- max{x, y} = z -max :: (Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] -max b x y z = - [ z ~>=~ x, - z ~>=~ y, - z ~<=~ x ~+~ bigM ~*~ var b, - z ~<=~ y ~+~ bigM ~*~ (constant 1 ~-~ var b) - ] - --- min{x, y} = z -min :: (Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] -min b x y z = - [ var z ~<=~ var x, - var z ~<=~ var y, - var z ~>=~ var x ~-~ bigM ~*~ (constant 1 ~-~ var b), - var z ~>=~ var y ~-~ bigM ~*~ var b - ] - -oneIsZero :: (Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] -oneIsZero (b1, x1) (b2, x2) = - mkC b1 x1 - <> mkC b2 x2 - <> [(var b1 ~+~ var b2) ~<=~ constant 1] - where - mkC b x = - [ var x ~<=~ bigM ~*~ var b - ] - -or :: (Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] -or b1 b2 c1 c2 = - mkC b1 c1 - <> mkC b2 c2 - <> [var b1 ~+~ var b2 ~<=~ constant 1] - where - mkC b (Constraint Equal l r) = - [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b), - l ~>=~ r ~-~ bigM ~*~ (constant 1 ~-~ var b) - ] - mkC b (Constraint LessEq l r) = - [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b) - ] - -bin :: (Num a) => v -> Constraint v a -bin v = Constraint LessEq (var v) (constant 1) - -(~==~) :: LSum v a -> LSum v a -> Constraint v a -l ~==~ r = Constraint Equal l r - -infix 4 ~==~ - -(~<=~) :: LSum v a -> LSum v a -> Constraint v a -l ~<=~ r = Constraint LessEq l r - -infix 4 ~<=~ - -(~>=~) :: (Num a) => LSum v a -> LSum v a -> Constraint v a -l ~>=~ r = Constraint LessEq (neg l) (neg r) - -infix 4 ~>=~ - -normalize :: (Eq a, Num a) => LSum v a -> LSum v a -normalize = LSum . M.filter (/= 0) . lsum - -var :: (Num a) => v -> LSum v a -var v = LSum $ M.singleton (Just v) 1 - -constant :: a -> LSum v a -constant = LSum . M.singleton Nothing - -cval :: (Num a, Ord v) => LSum v a -> a -cval = (! Nothing) - -(~+~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a -(LSum x) ~+~ (LSum y) = LSum $ M.unionWith (+) x y - -infixl 6 ~+~ - -(~-~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a -x ~-~ y = x ~+~ neg y - -infixl 6 ~-~ - -(~*~) :: (Num a) => a -> LSum v a -> LSum v a -a ~*~ s = fmap (a *) s - -infixl 7 ~*~ - -(!) :: (Num a, Ord v) => LSum v a -> Maybe v -> a -(LSum m) ! v = fromMaybe 0 (m M.!? v) - -neg :: (Num a) => LSum v a -> LSum v a -neg (LSum x) = LSum $ fmap negate x - --- | Converts a linear program given with a list of constraints --- into the standard form. -linearProgToLP :: - forall v a. - (Unbox a, Num a, Ord v) => - LinearProg v a -> - (LP a, Map Int v) -linearProgToLP (LinearProg otype obj cs) = - let c = mkRow $ convertObj otype obj - a = Matrix.fromVectors $ map (mkRow . fst) cs' - d = V.fromList $ map snd cs' - in (LP c a d, idxMap) - where - cs' = foldMap (convertEqCType . splitConstraint) cs - idxMap = - M.fromList $ - zip [0 ..] $ - catMaybes $ - M.keys $ - mconcat $ - map (lsum . fst) cs' - mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) - - convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] - convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] - convertEqCType (LessEq, s, a) = [(s, a)] - - splitConstraint :: Constraint v a -> (CType, LSum v a, a) - splitConstraint (Constraint ctype l r) = - let c = negate $ cval (l ~-~ r) - in (ctype, l ~-~ r ~-~ constant c, c) - - convertObj :: OptType -> LSum v a -> LSum v a - convertObj Maximize s = s - convertObj Minimize s = neg s - --- | Converts a linear program given with a list of constraints --- into the equational form. Assumes no <= constraints. -linearProgToLPE :: - forall v a. - (Unbox a, Num a, Ord v) => - LinearProg v a -> - (LPE a, Map Int v) -linearProgToLPE (LinearProg otype obj cs) = - let c = mkRow $ convertObj otype obj - a = Matrix.fromVectors $ map (mkRow . fst) cs' - d = V.fromList $ map snd cs' - in (LPE c a d, idxMap) - where - cs' = map (checkOnlyEqType . splitConstraint) cs - idxMap = - M.fromList $ - zip [0 ..] $ - catMaybes $ - M.keys $ - mconcat $ - map (lsum . fst) cs' - mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) - - splitConstraint :: Constraint v a -> (CType, LSum v a, a) - splitConstraint (Constraint ctype l r) = - let c = negate $ cval (l ~-~ r) - in (ctype, l ~-~ r ~-~ constant c, c) - - checkOnlyEqType :: (CType, LSum v a, a) -> (LSum v a, a) - checkOnlyEqType (Equal, s, a) = (s, a) - checkOnlyEqType (ctype, _, _) = error $ show ctype - - convertObj :: OptType -> LSum v a -> LSum v a - convertObj Maximize s = s - convertObj Minimize s = neg s diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs deleted file mode 100644 index 39ec16a39e..0000000000 --- a/src/Futhark/Solve/Matrix.hs +++ /dev/null @@ -1,330 +0,0 @@ -module Futhark.Solve.Matrix - ( Matrix (..), - toList, - toLists, - fromRowVector, - fromColVector, - fromVectors, - fromLists, - (@), - (!), - sliceCols, - getColM, - getCol, - setCol, - sliceRows, - getRowM, - getRow, - (<|>), - (<->), - addRow, - addRows, - imap, - generate, - identity, - diagonal, - (<.>), - (.*), - (*.), - (.+.), - (.-.), - rowEchelon, - filterRows, - deleteRow, - deleteCol, - ) -where - -import Data.List qualified as L -import Data.Map qualified as M -import Data.Vector.Unboxed (Unbox, Vector) -import Data.Vector.Unboxed qualified as V - --- A matrix represented as a 1D 'Vector'. -data Matrix a = Matrix - { elems :: Vector a, - nrows :: Int, - ncols :: Int - } - deriving (Eq) - -instance (Show a, Unbox a) => Show (Matrix a) where - show = - unlines . map show . toLists - -toList :: (Unbox a) => Matrix a -> [Vector a] -toList m = - map (\r -> V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] - -toLists :: (Unbox a) => Matrix a -> [[a]] -toLists m = - map (\r -> V.toList $ V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] - -fromRowVector :: (Unbox a) => Vector a -> Matrix a -fromRowVector v = - Matrix - { elems = v, - nrows = 1, - ncols = V.length v - } - -fromColVector :: (Unbox a) => Vector a -> Matrix a -fromColVector v = - Matrix - { elems = v, - nrows = V.length v, - ncols = 1 - } - -empty :: (Unbox a) => Matrix a -empty = Matrix mempty 0 0 - -fromVectors :: (Unbox a) => [Vector a] -> Matrix a -fromVectors [] = empty -fromVectors vs = - Matrix - { elems = V.concat vs, - nrows = length vs, - ncols = V.length $ head vs - } - -fromLists :: (Unbox a) => [[a]] -> Matrix a -fromLists xss = - Matrix - { elems = V.concat $ map V.fromList xss, - nrows = length xss, - ncols = length $ head xss - } - -class SelectCols a where - select :: Vector Int -> a -> a - (@) :: a -> Vector Int -> a - (@) = flip select - -infix 9 @ - -instance (Unbox a) => SelectCols (Vector a) where - select s v = V.map (v V.!) s - -instance (Unbox a) => SelectCols (Matrix a) where - select = sliceCols - -(!) :: (Unbox a) => Matrix a -> (Int, Int) -> a -m ! (r, c) = elems m V.! (ncols m * r + c) - -sliceCols :: (Unbox a) => Vector Int -> Matrix a -> Matrix a -sliceCols cols m = - Matrix - { elems = - V.generate (nrows m * V.length cols) $ \i -> - let col = cols V.! (i `rem` V.length cols) - row = i `div` V.length cols - in m ! (row, col), - nrows = nrows m, - ncols = V.length cols - } - -getColM :: (Unbox a) => Int -> Matrix a -> Matrix a -getColM col = sliceCols $ V.singleton col - -getCol :: (Unbox a) => Int -> Matrix a -> Vector a -getCol col = elems . getColM col - -setCol :: (Unbox a) => Int -> Vector a -> Matrix a -> Matrix a -setCol c col m = - m - { elems = - V.update_ (elems m) indices col - } - where - indices = V.generate (nrows m) $ - \r -> r * ncols m + c - -sliceRows :: (Unbox a) => Vector Int -> Matrix a -> Matrix a -sliceRows rows m = - Matrix - { elems = - V.generate (ncols m * V.length rows) $ \i -> - let row = rows V.! (i `rem` V.length rows) - col = i `div` V.length rows - in m ! (row, col), - nrows = V.length rows, - ncols = ncols m - } - -getRowM :: (Unbox a) => Int -> Matrix a -> Matrix a -getRowM row = sliceRows $ V.singleton row - -getRow :: (Unbox a) => Int -> Matrix a -> Vector a -getRow row = elems . getRowM row - -(<|>) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a -m1 <|> m2 = - generate f (nrows m1) (ncols m1 + ncols m2) - where - f r c - | c < ncols m1 = m1 ! (r, c) - | otherwise = m2 ! (r, c - ncols m1) - -(<->) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a -m1 <-> m2 = - generate f (nrows m1 + nrows m2) (ncols m1) - where - f r c - | r < nrows m1 = m1 ! (r, c) - | otherwise = m2 ! (r - nrows m1, c) - -addRow :: (Unbox a) => Matrix a -> Vector a -> Matrix a -addRow m v = - m - { elems = elems m V.++ v, - nrows = nrows m + 1 - } - -addRows :: (Unbox a) => Matrix a -> [Vector a] -> Matrix a -addRows = foldl addRow - -imap :: (Unbox a) => (Int -> Int -> a -> a) -> Matrix a -> Matrix a -imap f m = - m - { elems = V.imap g $ elems m - } - where - g i = - let r = i `div` ncols m - c = i `rem` nrows m - in f r c - -generate :: (Unbox a) => (Int -> Int -> a) -> Int -> Int -> Matrix a -generate f rows cols = - Matrix - { elems = - V.generate (rows * cols) $ \i -> - let r = i `div` cols - c = i `rem` cols - in f r c, - nrows = rows, - ncols = cols - } - -identity :: (Unbox a, Num a) => Int -> Matrix a -identity n = generate (\r c -> if r == c then 1 else 0) n n - -diagonal :: (Unbox a, Num a) => Vector a -> Matrix a -diagonal d = generate (\r c -> if r == c then d V.! r else 0) (V.length d) (V.length d) - -(<.>) :: (Unbox a, Num a) => Vector a -> Vector a -> a -v1 <.> v2 = V.sum $ V.zipWith (*) v1 v2 - -infixl 7 <.> - -(*.) :: (Unbox a, Num a) => Matrix a -> Vector a -> Vector a -m *. v = - V.generate (nrows m) $ \r -> - getRow r m <.> v - -infixl 7 *. - -(.*) :: (Unbox a, Num a) => Vector a -> Matrix a -> Vector a -v .* m = - V.generate (ncols m) $ \c -> - v <.> getCol c m - -infixl 7 .* - -(.-.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a -(.-.) = V.zipWith (-) - -infixl 6 .-. - -(.+.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a -(.+.) = V.zipWith (+) - -infixl 6 .+. - -swapRows :: (Unbox a) => Int -> Int -> Matrix a -> Matrix a -swapRows r1 r2 m = - m - { elems = - elems m `V.update` new - } - where - start1 = ncols m * r1 - start2 = ncols m * r2 - row1 = getRow r1 m - row2 = getRow r2 m - new = - V.imap (\i a -> (i + start1, a)) row2 - V.++ V.imap (\i a -> (i + start2, a)) row1 - --- todo: fix -update :: (Unbox a) => Matrix a -> Vector ((Int, Int), a) -> Matrix a -update m upds = - generate - ( \i j -> - case M.fromList (V.toList upds) M.!? (i, j) of - Nothing -> m ! (i, j) - Just x -> x - ) - (nrows m) - (ncols m) - --- This version doesn't maintain integrality of the entries. -rowEchelon :: (Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a -rowEchelon = rowEchelon' 0 0 - where - rowEchelon' h k m@(Matrix _ nr nc) - | h < nr && k < nc = - if m ! (pivot_row, k) == 0 - then rowEchelon' h (k + 1) m - else rowEchelon' (h + 1) (k + 1) clear_rows_below - | otherwise = m - where - pivot_row = - fst $ - L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ - [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] - m' = swapRows h pivot_row m - clear_rows_below = - update m' $ - V.fromList $ - [((i, k), 0) | i <- [h + 1 .. nr - 1]] - ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) - | i <- [h + 1 .. nr - 1], - let f = m' ! (i, k) / m' ! (h, k), - j <- [k + 1 .. nc - 1] - ] - --- TODO: fix. Something's wrong here, causes huge blow-up. --- rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a --- rowEchelon = rowEchelon' 0 0 --- where --- rowEchelon' h k m@(Matrix _ nr nc) --- | h < nr && k < nc = --- if m ! (pivot_row, k) == 0 --- then rowEchelon' h (k + 1) m --- else rowEchelon' (h + 1) (k + 1) clear_rows_below --- | otherwise = m --- where --- pivot_row = --- fst $ --- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ --- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] --- m' = swapRows h pivot_row m --- clear_rows_below = --- update m' $ --- V.fromList $ --- [((i, k), 0) | i <- [h + 1 .. nr - 1]] --- ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) --- | i <- [h + 1 .. nr - 1], --- j <- [k + 1 .. nc - 1] --- ] - -filterRows :: (Unbox a) => (Vector a -> Bool) -> Matrix a -> Matrix a -filterRows p = fromVectors . filter p . toList - -deleteRow :: (Unbox a) => Int -> Matrix a -> Matrix a -deleteRow n m = sliceRows (V.generate (nrows m - 1) (\r -> if r < n then r else r + 1)) m - -deleteCol :: (Unbox a) => Int -> Matrix a -> Matrix a -deleteCol n m = sliceCols (V.generate (ncols m - 1) (\c -> if c < n then c else c + 1)) m diff --git a/src/Futhark/Solve/Simplex.hs b/src/Futhark/Solve/Simplex.hs deleted file mode 100644 index 362b300038..0000000000 --- a/src/Futhark/Solve/Simplex.hs +++ /dev/null @@ -1,235 +0,0 @@ -module Futhark.Solve.Simplex - ( simplex, - simplexLP, - simplexProg, - findBasis, - ) -where - -import Data.List qualified as L -import Data.Map.Strict (Map) -import Data.Map.Strict qualified as M -import Data.Maybe -import Data.Vector.Unboxed (Unbox, Vector) -import Data.Vector.Unboxed qualified as V -import Futhark.Solve.LP (LP (..), LPE (..), LinearProg (..), convert, linearProgToLPE, rowEchelonLPE) -import Futhark.Solve.Matrix - --- | A tableau of an equational linear program @a * x = d@ is --- --- > x @ b = p + q * x @ n --- > --------------------- --- > z = z' + r^T * x @ n --- --- where @z = c^T * x@ and @b@ (@n@) is a vector containing the --- indices of basic (nonbasic) variables. --- --- The basic feasible solution corresponding to the above tableau is --- given by @x \@ b = p@, @x \@n = 0@ with the value of the objective --- equal to @z'@. - --- | Computes @r@ as given in the tableau above. -compR :: - (Num a, Unbox a) => - LPE a -> - Matrix a -> - Vector Int -> - Vector Int -> - Vector a -compR (LPE c a _) invA_B b n = - c @ n .-. c @ b .* invA_B .* a @ n - --- | @compQEnter prob invA_B b n enter@ computes the @enter@th --- column of @q@. -compQEnter :: - (Num a, Unbox a) => - LPE a -> - Matrix a -> - Int -> - Vector a -compQEnter (LPE _ a _) invA_B enter = - V.map negate $ invA_B *. getCol enter a - --- | Computes the objective given an inversion of @a@ and a basis. -compZ :: - (Num a, Unbox a) => - LPE a -> - Matrix a -> - Vector Int -> - a -compZ (LPE c _ d) invA_B b = - c @ b .* invA_B <.> d - --- | Constructs an auxiliary equational linear program to compute the --- initial feasible basis; returns the program along with a feasible --- basis. -mkAux :: (Ord a, Unbox a, Num a) => LPE a -> (LPE a, Vector Int, Vector Int) -mkAux (LPE _ a d) = (LPE c_aux a_aux d_aux, b_aux, n_aux) - where - c_aux = V.replicate (ncols a) 0 V.++ V.replicate (nrows a) (-1) - d_aux = V.map abs d - a_aux = - imap (\r _ e -> if (d V.! r) < 0 then negate e else e) a - <|> identity (nrows a) - b_aux = V.generate (nrows a) (+ ncols a) - n_aux = V.generate (ncols a) id - -fixDegenerateBasis :: - (Unbox a, Ord a, Fractional a, Show a) => - LPE a -> - Int -> - LPE a -> - (Matrix a, Vector a, Vector Int, Vector Int) -> - (LPE a, Matrix a, Vector a, Vector Int, Vector Int) -fixDegenerateBasis og_prob col prob (invA_B, p, b, n) - | Just exit_idx <- mexit_idx, - V.null (elim_row exit_idx) = - let prob' = - prob - { pA = deleteRow exit_idx (pA prob), - pd = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) $ - pd prob - } - invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B - p' = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) p - b' = - V.map snd $ - V.filter ((/= exit_idx) . fst) $ - V.imap (curry id) b - in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) - | Just exit_idx <- mexit_idx, - (enter, _) <- V.head (elim_row exit_idx) = - let enter_idx = fromJust $ V.findIndex (== enter) n - exit = b V.! exit_idx - in fixDegenerateBasis og_prob col prob $ - pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) - | otherwise = - let prob' = - prob - { pc = pc og_prob, - pA = sliceCols (V.generate col id) $ pA prob, - pd = V.map abs $ pd og_prob - } - in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) - where - mexit_idx = - fst <$> V.filter ((>= col) . snd) (V.imap (curry id) b) V.!? 0 - elim_row exit_idx = - V.filter ((/= 0) . snd) $ - V.map (\j -> (j, compQEnter prob invA_B j V.! exit_idx)) $ - V.generate col id - --- | Finds an initial feasible basis for an equational linear program. --- Returns 'Nothing' if the LP has no solution. Inverts some --- equations by multiplying by -1 so it also returns a modified (but --- equivalent) equational linear program. -findBasis :: - (Unbox a, Ord a, Fractional a, Show a) => - LPE a -> - Maybe (LPE a, Matrix a, Vector a, Vector Int, Vector Int) -findBasis prob = do - (invA_B, p, b, n) <- step p_aux (invA_B_aux, d_aux, b_aux, n_aux) - if compZ p_aux invA_B b == 0 - then Just $ fixDegenerateBasis prob (ncols $ pA prob) p_aux (invA_B, p, b, n) - else Nothing - where - (p_aux@(LPE _ _ d_aux), b_aux, n_aux) = mkAux prob - invA_B_aux = identity $ V.length b_aux - --- | Solves an equational linear program. Returns 'Nothing' if the --- program is infeasible or unbounded. Otherwise returns the optimal --- value and the solution. -simplex :: - (Unbox a, Ord a, Fractional a, Show a) => - LPE a -> - Maybe (a, Vector a) -simplex lpe = do - (lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe - (invA_B', p', b', n') <- step lpe' (invA_B, p, b, n) - let z = compZ lpe' invA_B' b' - sol = - V.map snd $ - V.fromList $ - L.sortOn fst $ - V.toList $ - V.zip (b' V.++ n') (p' V.++ V.replicate (V.length n') 0) - pure (z, sol) - --- | Solves a linear program. -simplexLP :: - (Unbox a, Ord a, Fractional a, Show a) => - LP a -> - Maybe (a, Vector a) -simplexLP lp = do - (opt, sol) <- simplex lpe - pure (opt, V.take (ncols $ lpA lp) sol) - where - lpe = convert lp - -simplexProg :: - (Unbox a, Ord a, Ord v, Fractional a, Show a) => - LinearProg v a -> - Maybe (a, Map v a) -simplexProg prog = do - (z, sol) <- simplex lpe - pure (z, M.fromList $ zipWith (\i x -> (idxMap M.! i, x)) [0 ..] $ V.toList sol) - where - (lpe, idxMap) = linearProgToLPE prog - -pivot :: - (Unbox a, Fractional a) => - LPE a -> - (Matrix a, Vector a, Vector Int, Vector Int) -> - (Int, Int) -> - (Int, Int) -> - (Matrix a, Vector a, Vector Int, Vector Int) -pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) = - (invA_B', p', b', n') - where - q_enter = compQEnter prob invA_B enter - b' = b V.// [(exit_idx, enter)] - n' = n V.// [(enter_idx, exit)] - e_inv_vec = - V.map - (/ abs (q_enter V.! exit_idx)) - (q_enter V.// [(exit_idx, 1)]) - genF row col = - (if row == exit_idx then 0 else invA_B ! (row, col)) - + (e_inv_vec V.! row) * invA_B ! (exit_idx, col) - invA_B' = generate genF (nrows invA_B) (ncols invA_B) - p' = p V.// [(exit_idx, 0)] .+. V.map (* (p V.! exit_idx)) e_inv_vec - --- | One step of the simplex algorithm. -step :: - (Unbox a, Ord a, Fractional a, Show a) => - LPE a -> - (Matrix a, Vector a, Vector Int, Vector Int) -> - Maybe (Matrix a, Vector a, Vector Int, Vector Int) -step prob (invA_B, p, b, n) - | Just enter_idx <- menter_idx = - let enter = n V.! enter_idx - q_enter = compQEnter prob invA_B enter - pq = - V.map (\(i, p_', q_) -> (i, -(p_' / q_))) $ - V.filter (\(_, _, q_) -> q_ < 0) $ - V.zip3 (V.generate (V.length q_enter) id) p q_enter - in if V.null pq - then Nothing - else - let exit_val = snd $ V.minimumOn snd pq - exit_cands = - V.map fst $ V.filter ((exit_val ==) . snd) pq - (exit_idx, exit) = - V.minimumOn snd $ - V.map (\i -> (i, b V.! i)) exit_cands - in step prob $ pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) - | otherwise = Just (invA_B, p, b, n) - where - r = compR prob invA_B b n - menter_idx = V.findIndex (> 0) r diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 34a0cfa95b..e683f66a7d 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -1698,6 +1698,18 @@ initialCtx = pure $ ValuePrim $ UnsignedValue x' ValueAD {} -> pure x -- FIXME: these do not carry signs. _ -> error $ "Cannot unsign: " <> show x + def "map" = Just $ + TermPoly Nothing $ \t -> do + t' <- evalTypeFully t + pure $ ValueFun $ \f -> pure . ValueFun $ \xs -> + case unfoldFunType t' of + ([_, _], ret_t) + | rowshape <- typeShape $ stripArray 1 ret_t -> + toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) + _ -> + error $ + "Invalid arguments to map intrinsic:\n" + ++ unlines [prettyString t, show f, show xs] def s | "reduce" `T.isPrefixOf` s = Just $ fun3 $ \f ne xs -> foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index f95db4deac..f080fc5013 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -864,6 +864,16 @@ intrinsics = $ array_a Unique $ shape [m, k, l] ), + ( "map", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), + array_a Observe $ shape [n] + ] + $ RetType [] + $ array_b Unique + $ shape [n] + ), ( "reduce", IntrinsicPolyFun [tp_a, sp_n] diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 9a75349517..3eb0674097 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -14,6 +14,7 @@ module Language.Futhark.TypeChecker.Constraints ) where +import Data.Bifunctor import Data.Loc import Data.Map qualified as M import Futhark.Util.Pretty @@ -72,6 +73,9 @@ instance Located (Reason t) where data CtTy d = CtEq (Reason (CtType d)) (TypeBase d NoUniqueness) (TypeBase d NoUniqueness) deriving (Show) +instance Functor CtTy where + fmap f (CtEq r x y) = CtEq (fmap (first f) r) (first f x) (first f y) + ctReason :: CtTy d -> Reason (CtType d) ctReason (CtEq r _ _) = r @@ -101,6 +105,12 @@ data TyVarInfo d TyVarSum Loc (M.Map Name [CtType d]) deriving (Show, Eq) +instance Functor TyVarInfo where + fmap _ (TyVarFree loc l) = TyVarFree loc l + fmap _ (TyVarPrim loc ts) = TyVarPrim loc ts + fmap f (TyVarRecord loc m) = TyVarRecord loc $ M.map (first f) m + fmap f (TyVarSum loc m) = TyVarSum loc $ M.map (map (first f)) m + prettyTyVarInfo :: (Pretty (Shape d)) => TyVarInfo d -> Doc a prettyTyVarInfo (TyVarFree _ l) = "free" <+> pretty l prettyTyVarInfo (TyVarPrim _ pts) = "∈" <+> pretty pts diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index b82de905f5..8b2f6c0be3 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -4,247 +4,15 @@ module Language.Futhark.TypeChecker.Rank ) where -import Control.Monad -import Control.Monad.Reader -import Control.Monad.State +import Control.Monad (void) import Data.Bifunctor -import Data.Functor.Identity -import Data.List qualified as L -import Data.Map (Map) import Data.Map qualified as M -import Data.Maybe import Futhark.IR.Pretty () -import Futhark.Solve.GLPK -import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) -import Futhark.Solve.LP qualified as LP -import Futhark.Util (debugTraceM) -import Futhark.Util.Pretty import Language.Futhark hiding (ScalarType) -import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints -import Language.Futhark.TypeChecker.Monad -import System.IO.Unsafe - -type LSum = LP.LSum VName Int - -type Constraint = LP.Constraint VName Int - -type LinearProg = LP.LinearProg VName Int - -class Rank a where - rank :: a -> LSum - -instance Rank VName where - rank = var - -instance Rank SComp where - rank SDim = constant 1 - rank (SVar v) = var v - -instance Rank (Shape SComp) where - rank = foldr (\d r -> rank d ~+~ r) (constant 0) . shapeDims - -instance Rank (ScalarTypeBase SComp u) where - rank Prim {} = constant 0 - rank (TypeVar _ (QualName [] v) []) = var v - rank (TypeVar {}) = constant 0 - rank (Arrow {}) = constant 0 - rank (Record {}) = constant 0 - rank (Sum {}) = constant 0 - -instance Rank (TypeBase SComp u) where - rank (Scalar t) = rank t - rank (Array _ shape t) = rank shape ~+~ rank t - -distribAndSplitArrows :: CtTy d -> [CtTy d] -distribAndSplitArrows (CtEq r t1 t2) = - splitArrows $ CtEq r (distribute t1) (distribute t2) - where - distribute :: TypeBase dim as -> TypeBase dim as - distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = - Scalar $ - Arrow - u - Unnamed - mempty - (arrayOf s ta) - (RetType rd $ distribute $ arrayOfWithAliases Nonunique s tr) - distribute t = t - - splitArrows - ( CtEq - reason - (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) - (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) - ) = - splitArrows (CtEq reason t1a t2a) ++ splitArrows (CtEq reason t1r' t2r') - where - t1r' = t1r `setUniqueness` NoUniqueness - t2r' = t2r `setUniqueness` NoUniqueness - splitArrows c = [c] - -distribAndSplitCnstrs :: CtTy d -> [CtTy d] -distribAndSplitCnstrs ct@(CtEq r t1 t2) = - ct : splitCnstrs (CtEq r (distribute1 t1) (distribute1 t2)) - where - distribute1 :: TypeBase dim as -> TypeBase dim as - distribute1 (Array u s (Record ts1)) = - Scalar $ Record $ fmap (arrayOfWithAliases u s) ts1 - distribute1 (Array u s (Sum cs)) = - Scalar $ Sum $ (fmap . fmap) (arrayOfWithAliases u s) cs - distribute1 t = t - - -- FIXME. Should check for key set equality here. - splitCnstrs (CtEq reason (Scalar (Record ts1)) (Scalar (Record ts2))) = - concat $ zipWith (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems ts1) (M.elems ts2) - splitCnstrs (CtEq reason (Scalar (Sum cs1)) (Scalar (Sum cs2))) = - concat $ concat $ (zipWith . zipWith) (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems cs1) (M.elems cs2) - splitCnstrs _ = [] - -data RankState = RankState - { rankBinVars :: Map VName VName, - rankCounter :: !Int, - rankConstraints :: [Constraint], - rankObj :: LSum - } - -newtype RankM a = RankM {runRankM :: State RankState a} - deriving (Functor, Applicative, Monad, MonadState RankState) - -incCounter :: RankM Int -incCounter = do - s <- get - put s {rankCounter = rankCounter s + 1} - pure $ rankCounter s - -binVar :: VName -> RankM VName -binVar sv = do - mbv <- gets ((M.!? sv) . rankBinVars) - case mbv of - Nothing -> do - bv <- VName ("b_" <> baseName sv) <$> incCounter - modify $ \s -> - s - { rankBinVars = M.insert sv bv $ rankBinVars s, - rankConstraints = [bin bv, var bv ~<=~ var sv] <> rankConstraints s - } - pure bv - Just bv -> pure bv - -addConstraints :: [Constraint] -> RankM () -addConstraints cs = - modify $ \s -> s {rankConstraints = cs <> rankConstraints s} - -addConstraint :: Constraint -> RankM () -addConstraint = addConstraints . pure - -addObj :: SVar -> RankM () -addObj sv = - modify $ \s -> s {rankObj = rankObj s ~+~ var sv} - -addCt :: CtTy SComp -> RankM () -addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2 - -addCtAM :: CtAM -> RankM () -addCtAM (CtAM _ r m f) = do - b_r <- binVar r - b_m <- binVar m - b_max <- VName "c_max" <$> incCounter - tr <- VName ("T_" <> baseName r) <$> incCounter - addConstraints [bin b_max, var b_max ~<=~ var tr] - addConstraints $ oneIsZero (b_r, r) (b_m, m) - addConstraints $ LP.max b_max (constant 0) (rank r ~-~ rank f) (var tr) - addObj m - addObj tr - -addTyVarInfo :: TyVar -> (Int, TyVarInfo d) -> RankM () -addTyVarInfo _ (_, TyVarFree {}) = pure () -addTyVarInfo tv (_, TyVarPrim {}) = - addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarRecord {}) = - addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarSum {}) = - addConstraint $ rank tv ~==~ constant 0 - -mkLinearProg :: [CtTy SComp] -> [CtAM] -> TyVars d -> LinearProg -mkLinearProg cs cs_am tyVars = - LP.LinearProg - { optType = Minimize, - objective = rankObj finalState, - constraints = rankConstraints finalState - } - where - initState = - RankState - { rankBinVars = mempty, - rankCounter = 0, - rankConstraints = mempty, - rankObj = constant 0 - } - buildLP = do - mapM_ addCt cs - mapM_ addCtAM cs_am - mapM_ (uncurry addTyVarInfo) $ M.toList tyVars - finalState = flip execState initState $ runRankM buildLP - -ambigCheckLinearProg :: LinearProg -> (Int, Map VName Int) -> LinearProg -ambigCheckLinearProg prog (opt, ranks) = - prog - { constraints = - -- https://yetanothermathprogrammingconsultant.blogspot.com/2011/10/integer-cuts.html - [ lsum (var <$> M.keys one_bins) - ~-~ lsum (var <$> M.keys zero_bins) - ~<=~ constant (fromIntegral $ length one_bins) - ~-~ constant 1, - objective prog ~==~ constant (fromIntegral opt) - ] - ++ constraints prog - } - where - -- We really need to track which variables are binary in the LinearProg - is_bin_var = ("b_" `L.isPrefixOf`) . baseString - one_bins = M.filterWithKey (\k v -> is_bin_var k && v == 1) ranks - zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks - lsum = foldr (~+~) (constant 0) - -enumerateRankSols :: LinearProg -> [Map VName Int] -enumerateRankSols prog = - take 5 $ - takeSolns $ - iterate next_sol $ - (prog,) <$> run_glpk prog - where - run_glpk = unsafePerformIO . glpk - next_sol m = do - (prog', sol') <- m - guard (fst sol' /= 0) - let prog'' = ambigCheckLinearProg prog' sol' - sol'' <- run_glpk prog'' - pure (prog'', sol'') - takeSolns [] = [] - takeSolns (Nothing : _) = [] - takeSolns (Just (_, (_, r)) : xs) = r : takeSolns xs - -solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] -solveRankILP loc prog = do - debugTraceM 3 $ - unlines - [ "## solveRankILP", - prettyString prog - ] - case enumerateRankSols prog of - [] -> typeError loc mempty "Rank ILP cannot be solved." - rs -> do - debugTraceM 3 "## rank maps" - forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> - debugTraceM 3 $ - unlines $ - "\n## rank map " <> prettyString i - : map prettyString (M.toList r) - pure rs rankAnalysis1 :: - (MonadTypeChecker m) => + (Monad m) => SrcLoc -> ([CtTy SComp], [CtAM]) -> TyVars SComp -> @@ -258,21 +26,19 @@ rankAnalysis1 :: Exp, Maybe (TypeExp Exp VName) ) -rankAnalysis1 loc (cs, cs_am) tyVars artificial params body retdecl = do - solutions <- rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl - case solutions of - [sol] -> pure sol - sols -> do - let (_, _, bodies', _) = L.unzip4 sols - typeError loc mempty $ - stack $ - [ "Rank ILP is ambiguous.", - "Choices:" - ] - ++ map pretty bodies' +rankAnalysis1 _loc (cs, _cs_am) tyVars artificial params body retdecl = + pure + ( ( map void cs, + M.map (first (const ())) artificial, + fmap (second void) tyVars + ), + params, + body, + retdecl + ) rankAnalysis :: - (MonadTypeChecker m) => + (Monad m) => SrcLoc -> ([CtTy SComp], [CtAM]) -> TyVars SComp -> @@ -287,225 +53,14 @@ rankAnalysis :: Maybe (TypeExp Exp VName) ) ] -rankAnalysis _ ([], []) tyVars artificial params body retdecl = do - (_, artificial', tyVars') <- substRankInfo ([], []) artificial tyVars mempty - pure [(([], artificial', tyVars'), params, body, retdecl)] -rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl = do - debugTraceM 3 $ - unlines - [ "##rankAnalysis", - "cs:", - unlines $ map prettyString cs, - "cs':", - unlines $ map prettyString cs' - ] - rank_maps <- solveRankILP loc (mkLinearProg cs' cs_am tyVars) - cts_tyvars' <- mapM (substRankInfo (cs, cs_am) artificial tyVars) rank_maps - let bodys = map (`updAM` body) rank_maps - params' = map ((`map` params) . updAMPat) rank_maps - retdecls = map ((<$> retdecl) . updAMTypeExp) rank_maps - pure $ L.zip4 cts_tyvars' params' bodys retdecls - where - cs' = - foldMap distribAndSplitCnstrs $ - foldMap distribAndSplitArrows cs - -type RankMap = M.Map VName Int - -substRankInfo :: - (MonadTypeChecker m) => - ([CtTy SComp], [CtAM]) -> - M.Map VName (CtType SComp) -> - TyVars SComp -> - RankMap -> - m ([CtTy ()], M.Map VName (CtType ()), TyVars ()) -substRankInfo (cs, _cs_am) artificial tyVars rankmap = do - ((cs', artificial', tyVars'), new_cs, new_tyVars) <- - runSubstT tyVars rankmap $ - (,,) - <$> traverse substRanksCt cs - <*> traverse substRanksType artificial - <*> substRanksTyVars tyVars - pure (cs' <> new_cs, artificial', new_tyVars <> tyVars') - -runSubstT :: - (MonadTypeChecker m) => - TyVars SComp -> - RankMap -> - SubstT m a -> - m (a, [CtTy ()], TyVars ()) -runSubstT tyVars rankmap (SubstT m) = do - let env = - SubstEnv - { envTyVars = tyVars, - envRanks = rankmap - } - - s = - SubstState - { substTyVars = mempty, - substNewVars = mempty, - substNewCts = mempty - } - (a, s') <- runReaderT (runStateT m s) env - pure (a, substNewCts s', substTyVars s') - -newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) - deriving - ( Functor, - Applicative, - Monad, - MonadState SubstState, - MonadReader SubstEnv - ) - -data SubstEnv = SubstEnv - { envTyVars :: TyVars SComp, - envRanks :: RankMap - } - -data SubstState = SubstState - { substTyVars :: TyVars (), - substNewVars :: Map TyVar TyVar, - substNewCts :: [CtTy ()] - } - -instance MonadTrans SubstT where - lift = SubstT . lift . lift - -rankToShape :: (Monad m) => VName -> SubstT m (Shape ()) -rankToShape x = do - rs <- asks envRanks - pure $ Shape $ replicate (fromJust $ rs M.!? x) () - -newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar -newTyVar t = do - t' <- lift $ newTypeName (baseName t) - shape <- rankToShape t - loc <- asks ((locOf . snd . fromJust . (M.!? t)) . envTyVars) - modify $ \s -> - s - { substNewVars = M.insert t t' $ substNewVars s, - substNewCts = - CtEq - (Reason loc) - (Scalar (TypeVar mempty (QualName [] t) [])) - (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) - : substNewCts s - } - pure t' - -addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () -addRankInfo t = do - rs <- asks envRanks - if fromMaybe 0 (rs M.!? t) == 0 - then pure () - else do - new_vars <- gets substNewVars - maybe new_var (const $ pure ()) $ new_vars M.!? t - where - new_var = do - t' <- newTyVar t - old_tyvars <- asks envTyVars - let (level, tvinfo) = fromJust $ old_tyvars M.!? t - l = case tvinfo of - TyVarFree _ tvinfo_l -> tvinfo_l - _ -> Unlifted - tvinfo' <- substRanksTyVarInfo tvinfo - modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo') $ substTyVars s} - modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree (locOf tvinfo) l) $ substTyVars s} - -substRanksShape :: (Monad m) => Shape SComp -> SubstT m (Shape ()) -substRanksShape = foldM (\s d -> (s <>) <$> instDim d) mempty - where - instDim SDim = pure $ Shape [()] - instDim (SVar x) = rankToShape x - -substRanksType :: (MonadTypeChecker m) => TypeBase SComp u -> SubstT m (TypeBase () u) -substRanksType (Scalar (TypeVar vn (QualName qs x) targs)) = do - when (null qs) $ addRankInfo x - targs' <- mapM onTypeArg targs - pure $ Scalar $ TypeVar vn (QualName qs x) targs' - where - onTypeArg (TypeArgType t) = TypeArgType <$> substRanksType t - -- SVar cannot occur as argument to abstract ype. - onTypeArg (TypeArgDim _) = pure $ TypeArgDim () -substRanksType (Scalar (Arrow u p d ta (RetType retdims tr))) = do - ta' <- substRanksType ta - tr' <- substRanksType tr - pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) -substRanksType (Scalar (Record fs)) = - Scalar . Record <$> traverse substRanksType fs -substRanksType (Scalar (Sum cs)) = - Scalar . Sum <$> (traverse . traverse) substRanksType cs -substRanksType (Scalar (Prim pt)) = pure $ Scalar $ Prim pt -substRanksType (Array u shape t) = do - shape' <- substRanksShape shape - t' <- substRanksType $ Scalar t - pure $ arrayOfWithAliases u shape' t' - -substRanksCt :: (MonadTypeChecker m) => CtTy SComp -> SubstT m (CtTy ()) -substRanksCt (CtEq r t1 t2) = - CtEq - <$> traverse substRanksType r - <*> substRanksType t1 - <*> substRanksType t2 - -substRanksTyVarInfo :: (MonadTypeChecker m) => TyVarInfo SComp -> SubstT m (TyVarInfo ()) -substRanksTyVarInfo (TyVarFree loc l) = pure $ TyVarFree loc l -substRanksTyVarInfo (TyVarPrim loc ts) = pure $ TyVarPrim loc ts -substRanksTyVarInfo (TyVarRecord loc fs) = - TyVarRecord loc <$> traverse substRanksType fs -substRanksTyVarInfo (TyVarSum loc cs) = - TyVarSum loc <$> traverse (traverse substRanksType) cs - -substRanksTyVars :: (MonadTypeChecker m) => TyVars SComp -> SubstT m (TyVars ()) -substRanksTyVars = traverse $ \(lvl, tv) -> (lvl,) <$> substRanksTyVarInfo tv - -updAM :: RankMap -> Exp -> Exp -updAM rank_map e = - case e of - AppExp (Apply f args loc) res -> - let f' = updAM rank_map f - args' = fmap (bimap (fmap $ second upd) (updAM rank_map)) args - in AppExp (Apply f' args' loc) res - AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> - AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res - OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc -> - OpSectionRight - name - t - (updAM rank_map arg) - (Info (pa, t1a), Info (pb, t1b, argext, upd am)) - t2 - loc - OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc -> - OpSectionLeft - name - t - (updAM rank_map arg) - (Info (pa, t1a, argext, upd am), Info (pb, t1b)) - (ret, retext) - loc - _ -> runIdentity $ astMap mapper e - where - dimToRank (Var (QualName [] x) _ _) = - replicate (rank_map M.! x) (TupLit mempty mempty) - dimToRank e' = error $ prettyString e' - shapeToRank = Shape . foldMap dimToRank - upd (AutoMap r m f) = - AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) - mapper = identityMapper {mapOnExp = pure . updAM rank_map} - -updAMPat :: RankMap -> Pat ParamType -> Pat ParamType -updAMPat rank_map p = runIdentity $ astMap m p - where - m = identityMapper {mapOnExp = pure . updAM rank_map} - -updAMTypeExp :: - RankMap -> - TypeExp Exp VName -> - TypeExp Exp VName -updAMTypeExp rank_map te = runIdentity $ astMap m te - where - m = identityMapper {mapOnExp = pure . updAM rank_map} +rankAnalysis _loc (cs, _cs_am) tyVars artificial params body retdecl = do + pure + [ ( ( map void cs, + M.map (first (const ())) artificial, + fmap (second void) tyVars + ), + params, + body, + retdecl + ) + ] diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 4314fc0018..cf65002f83 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1214,6 +1214,14 @@ localChecks = void . check e <$ case ty of Info (Scalar (Prim t)) -> errorBounds (inBoundsI (-x) t) (-x) t (loc1 <> loc2) _ -> error "Inferred type of int literal is not a number" + check e@(AppExp (BinOp (QualName [] v, _) _ (x, _) _ loc) _) + | baseName v == "==", + Array {} <- typeOf x, + baseTag v <= maxIntrinsicTag = do + warn loc $ + textwrap + "Comparing arrays with \"==\" is deprecated and will stop working in a future revision of the language." + recurse e check e = recurse e recurse = astMap identityMapper {mapOnExp = check} diff --git a/src/Language/Futhark/TypeChecker/Terms/Unsized.hs b/src/Language/Futhark/TypeChecker/Terms/Unsized.hs index 8c5f7a80cb..a0794a4948 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Unsized.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Unsized.hs @@ -233,11 +233,6 @@ newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniquen newTypeOverloaded loc name pts = tyVarType NoUniqueness <$> newTyVarWith name (TyVarPrim (locOf loc) pts) -newSVar :: loc -> Name -> TermM SVar -newSVar _loc desc = do - i <- incCounter - newID $ mkTypeVarName desc i - newArtificial :: u -> TypeBase SComp u -> TermM (TypeBase Size u) newArtificial u t = do v <- newID "artificial" @@ -286,12 +281,6 @@ ctEq reason t1 t2 = t1' = t1 `setUniqueness` NoUniqueness t2' = t2 `setUniqueness` NoUniqueness -ctAM :: Reason (CtType SComp) -> SVar -> SVar -> Shape SComp -> TermM () -ctAM reason r m f = - modify $ \s -> s {termAM = ct : termAM s} - where - ct = CtAM reason r m f - localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -658,52 +647,23 @@ checkApplyOne :: (Shape Size, Type) -> (Maybe Exp, Shape Size, Type) -> TermM (Type, AutoMap) -checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do +checkApplyOne loc fname (_fframe, ftype) (arg, _argframe, argtype) = do (a, b) <- split ftype - r <- newSVar loc "R" - m <- newSVar loc "M" - let unit_info = Info $ Scalar $ Prim Bool - r_var = Var (QualName [] r) unit_info mempty - m_var = Var (QualName [] m) unit_info mempty - lhs = arrayOf (toShape (SVar r)) argtype - rhs = arrayOf (toShape (SVar m)) a - ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) + let lhs = argtype + rhs = a let reason = case arg of Just arg' -> ReasonApply (locOf arg) fname arg' lhs rhs Nothing -> Reason (locOf loc) ctEq reason lhs rhs - debugTraceM 3 $ - unlines - [ "## checkApplyOne", - "## fname", - prettyString fname, - "## (fframe, ftype)", - prettyString (fframe, ftype), - "## (argframe, argtype)", - prettyString (argframe, argtype), - "## r", - prettyString r, - "## m", - prettyString m, - "## lhs", - prettyString lhs, - "## rhs", - prettyString rhs, - "## ret", - prettyString $ arrayOf (toShape (SVar m)) b - ] pure - ( arrayOf (toShape (SVar m)) b, + ( b, AutoMap - { autoRep = toShape r_var, - autoMap = toShape m_var, - autoFrame = toShape m_var <> fframe + { autoRep = mempty, + autoMap = mempty, + autoFrame = mempty } ) where - toSComp (Var (QualName [] x) _ _) = SVar x - toSComp _ = error "" - toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) split (Array _u s t) = do @@ -1255,6 +1215,8 @@ doDefault :: Either [PrimType] (TypeBase () NoUniqueness) -> TermM (TypeBase () NoUniqueness) doDefault tyvars_at_toplevel v (Left pts) + | [pt] <- pts = + pure $ Scalar $ Prim pt | Signed Int32 `elem` pts = do when (v `elem` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to i32." diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs index 75ebead21d..82afb9e190 100644 --- a/src/Language/Futhark/TypeChecker/TySolve.hs +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -519,8 +519,8 @@ unionTyVars reason bcs v v_node t_node = do TyVarPrim t_loc t_pts ) -> let pts = L.intersect v_pts t_pts - in if null pts - then + in case pts of + [] -> pure $ Left ( locOf reason, @@ -530,7 +530,7 @@ unionTyVars reason bcs v v_node t_node = do "with type that must be one of" indent 2 (pretty t_pts) ) - else pure $ Right $ Just $ Unsolved $ TyVarPrim t_loc pts + _ -> pure $ Right $ Just $ Unsolved $ TyVarPrim t_loc pts (Unsolved (TyVarPrim _ v_pts), TyVarRecord {}) -> pure $ Left diff --git a/tests/automap/ambiguous0.fut b/tests/automap/ambiguous0.fut deleted file mode 100644 index 8c1ec556c3..0000000000 --- a/tests/automap/ambiguous0.fut +++ /dev/null @@ -1,4 +0,0 @@ --- == --- error: ambiguous - -def ambig (xss : [][]i32) = i64.sum (length xss) diff --git a/tests/automap/bool1.fut b/tests/automap/bool1.fut deleted file mode 100644 index f3fe08213e..0000000000 --- a/tests/automap/bool1.fut +++ /dev/null @@ -1,6 +0,0 @@ --- == --- entry: f --- input { [true, true, false] [false, true, true] } --- output { [true, true, true] } - -def f [m] (xs: [m]bool) (ys: [m]bool) = xs || ys diff --git a/tests/automap/combinations.fut b/tests/automap/combinations.fut deleted file mode 100644 index 7d77e85abb..0000000000 --- a/tests/automap/combinations.fut +++ /dev/null @@ -1,38 +0,0 @@ --- All the various ways one can imagine automapping a very simple program. - -def plus (x: i32) (y: i32) = x + y - --- == --- entry: vecint --- input { [1,2,3] } output { [3,4,5] } - -entry vecint (x: []i32) = plus x 2 - --- == --- entry: vecvec --- input { [1,2,3] } output { [2,4,6] } - -entry vecvec (x: []i32) = plus x x - --- == --- entry: matint --- input { [[1,2],[3,4]] } output { [[3,4],[5,6]] } - -entry matint (x: [][]i32) = plus x 2 - --- == --- entry: matmat --- input { [[1,2],[3,4]] } output { [[2,4],[6,8]] } - -entry matmat (x: [][]i32) = plus x x - --- == --- entry: matvec --- input { [[1,2],[3,4]] [5,6] } output { [[6,8],[8,10]] } - -entry matvec (x: [][]i32) (y: []i32) = plus x y - --- == --- entry: vecvecvec --- input { [1,2,3] } output { [3,6,9] } -entry vecvecvec (x: []i32) = (\x y z -> x + y + z) x x x diff --git a/tests/automap/equality1.fut b/tests/automap/equality1.fut deleted file mode 100644 index b2a173f30d..0000000000 --- a/tests/automap/equality1.fut +++ /dev/null @@ -1,23 +0,0 @@ --- == --- entry: bigger_to_smaller --- input { [[1,2],[3,4]] [1,2] } --- output { [[true, true], [false, false]] } - --- == --- entry: smaller_to_bigger --- input { [[1,2],[3,4]] [1,2] } --- output { [[true, true], [false, false]] } - --- == --- entry: smaller_to_bigger2 --- input { [[1,2],[3,4]] 1 } --- output { [[true,false],[false,false]]} - -entry bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = - xss == ys - -entry smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = - ys == xss - -entry smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = - z == xss diff --git a/tests/automap/lambda.fut b/tests/automap/lambda.fut deleted file mode 100644 index 1bb7ed26e3..0000000000 --- a/tests/automap/lambda.fut +++ /dev/null @@ -1,6 +0,0 @@ --- == --- entry: main --- random input { [10]f32 [10]f32 } - -entry main [n](xs: [n]f32) (ys: [n]f32): [n]f32 = - map2 (*) xs ys diff --git a/tests/automap/leetcode.fut b/tests/automap/leetcode.fut deleted file mode 100644 index 43a50cb2b8..0000000000 --- a/tests/automap/leetcode.fut +++ /dev/null @@ -1,4 +0,0 @@ -def outerprod f x y = map (f >-> flip map y) x -def bidd A = outerprod (==) (indices A) (indices A) -def xmat A = bidd A || reverse (bidd A) -def check_matrix (A : [][]i32) = xmat A == (A != 0) |> flatten |> and diff --git a/tests/automap/map0.fut b/tests/automap/map0.fut deleted file mode 100644 index a5ab0887ae..0000000000 --- a/tests/automap/map0.fut +++ /dev/null @@ -1,8 +0,0 @@ --- == --- entry: main --- input { [0,1,2,3] } --- output { [1,2,3,4] } - -def automap 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = f as - -entry main (x: []i32) = automap (+1) x diff --git a/tests/automap/mri-q-qr.fut b/tests/automap/mri-q-qr.fut deleted file mode 100644 index 8004f7da5d..0000000000 --- a/tests/automap/mri-q-qr.fut +++ /dev/null @@ -1,2 +0,0 @@ -def qr [numX][numK] (expArgs : [numX][numK]f32) (phiMag : [numK]f32) : [numX]f32 = - f32.sum (f32.cos expArgs * phiMag) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut deleted file mode 100644 index 270e18195a..0000000000 --- a/tests/automap/mri-q.fut +++ /dev/null @@ -1,41 +0,0 @@ --- == --- entry: main --- random input { [12]f32 [12]f32 [12]f32 [10]f32 [10]f32 [10]f32 [12]f32 [12]f32 } --- output { true } - -def main_orig [numK][numX] - (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) - (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) - (phiR: [numK]f32) (phiI: [numK]f32) - : ([numX]f32, [numX]f32) = - let phiMag = map2 (\r i -> r*r + i*i) phiR phiI - let expArgs = map3 (\x_e y_e z_e -> - map (2.0f32*f32.pi*) - (map3 (\kx_e ky_e kz_e -> - kx_e * x_e + ky_e * y_e + kz_e * z_e) - kx ky kz)) - x y z - let qr = map1 (map f32.cos >-> map2 (*) phiMag >-> f32.sum) expArgs - let qi = map1 (map f32.sin >-> map2 (*) phiMag >-> f32.sum) expArgs - in (qr, qi) - -def main_am [numK][numX] - (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) - (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) - (phiR: [numK]f32) (phiI: [numK]f32) - : ([numX]f32, [numX]f32) = - let phiMag = phiR * phiR + phiI * phiI - let expArgs = map3 (\x_e y_e z_e -> - 2.0*f32.pi*(kx*x_e + ky*y_e + kz*z_e)) - x y z - let qr = f32.sum (f32.cos expArgs * phiMag) - let qi = f32.sum (f32.sin expArgs * phiMag) - in (qr, qi) - -entry main [numK][numX] - (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) - (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) - (phiR: [numK]f32) (phiI: [numK]f32) = - let (qr, qi) = main_orig kx ky kz x y z phiR phiI - let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI - in and (qr == qr_am && qi == qi_am) diff --git a/tests/automap/operator1.fut b/tests/automap/operator1.fut deleted file mode 100644 index 464a8b79c4..0000000000 --- a/tests/automap/operator1.fut +++ /dev/null @@ -1,9 +0,0 @@ --- == --- entry: main --- input { [[1,2],[3,4]] [10,20] } --- output { [[11, 22],[13, 24]] } - -def (+^) [n] (xs: [n]i32) (ys: [n]i32) : [n]i32 = xs + ys - ---entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = --- xss +^ ys diff --git a/tests/automap/optionpricing.fut b/tests/automap/optionpricing.fut deleted file mode 100644 index c4c916521f..0000000000 --- a/tests/automap/optionpricing.fut +++ /dev/null @@ -1,78 +0,0 @@ --- == --- entry: sobolIndR --- random input { [12][10]i32 i32 } --- output { true } - --- == --- entry: sobolRecI --- random input { [12][10]i32 [12]i32 i32} --- output { true } - --- == --- entry: sobolReci2 --- random input { [12][10]i32 [12]i32 i32} --- output { true } - -def grayCode(x: i32): i32 = (x >> 1) ^ x - -def testBit(n: i32, ind: i32): bool = - let t = (1 << ind) in (n & t) == t - -def xorInds [num_bits] (n: i32) (dir_vs: [num_bits]i32): i32 = - let reldv_vals = map (\(dv: i32, i): i32 -> - if testBit(grayCode(n),i32.i64 i) - then dv else 0 - ) (zip (dir_vs) (iota(num_bits)) ) in - reduce (^) 0 (reldv_vals ) - - -def sobolIndI [len] (dir_vs: [len][]i32, n: i32 ): [len]i32 = - map (xorInds(n)) (dir_vs ) - -def index_of_least_significant_0(num_bits: i32, n: i32): i32 = - let (goon,k) = (true,0) in - let (_,k,_) = loop ((goon,k,n)) for i < num_bits do - if(goon) - then if (n & 1) == 1 - then (true, k+1, n>>1) - else (false,k, n ) - else (false,k, n ) - in k - -def recM [len][num_bits] (sob_dirs: [len][num_bits]i32, i: i32 ): [len]i32 = - let bit= index_of_least_significant_0(i32.i64 num_bits,i) in - map (\(row: []i32): i32 -> row[bit]) (sob_dirs ) - -def sobolIndR_orig [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = - let divisor = 2.0 ** f32.i64(num_bits) - let arri = map (xorInds n) dir_vs - in map (\x -> f32.i32(x) / divisor) arri - -def sobolRecI_orig [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = - let bit = index_of_least_significant_0(i32.i64 num_bits, x) - in map2 (\vct_row prev -> vct_row[bit] ^ prev) sob_dir_vs prev - -def sobolReci2_orig [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= - let col = recM(sob_dirs, i) - in map2 (^) prev col - -def sobolIndR_am [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = - let divisor = 2.0 ** f32.i64(num_bits) - let arri = xorInds n dir_vs - in f32.i32 arri / divisor - -def sobolRecI_am [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = - let bit = index_of_least_significant_0(i32.i64 num_bits, x) - in sob_dir_vs[:,bit] ^ prev - -def sobolReci2_am [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= - prev ^ recM(sob_dirs, i) - -entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): bool = - and (sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n) - -entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): bool = - and (sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x)) - -entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): bool = - and (sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i)) diff --git a/tests/automap/pagerank.fut b/tests/automap/pagerank.fut deleted file mode 100644 index 3552990144..0000000000 --- a/tests/automap/pagerank.fut +++ /dev/null @@ -1,18 +0,0 @@ --- == --- entry: calculate_dangling_ranks --- random input { [12]f32 [12]i32} --- output { true } - -def calculate_dangling_ranks_orig [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = - let zipped = zip sizes ranks - let weights = map (\(size, rank) -> if size == 0 then rank else 0f32) zipped - let total = f32.sum weights / f32.i64 n - in map (+total) ranks - -def calculate_dangling_ranks_am [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = - let weights = f32.bool (sizes == 0) * ranks - let total = f32.sum weights / f32.i64 n - in ranks + total - -entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): bool = - and (calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes) diff --git a/tests/automap/project.fut b/tests/automap/project.fut deleted file mode 100644 index 2902d0565a..0000000000 --- a/tests/automap/project.fut +++ /dev/null @@ -1,9 +0,0 @@ --- == --- entry: main --- input { [1,2,3] [4,5,6] } --- output { [1,2,3,4,5,6] } - -entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = - let xsys = zip xs ys - in xsys.0 ++ xsys.1 - diff --git a/tests/automap/projsec1.fut b/tests/automap/projsec1.fut deleted file mode 100644 index 485c977bc5..0000000000 --- a/tests/automap/projsec1.fut +++ /dev/null @@ -1,9 +0,0 @@ --- == --- entry: main --- input { [1,2,3] [4,5,6] } --- output { [1,2,3,4,5,6] } - -entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = - let xsys = zip xs ys - in (.0) xsys ++ (.1) xsys - diff --git a/tests/automap/same_typevar.fut b/tests/automap/same_typevar.fut deleted file mode 100644 index 260a00b785..0000000000 --- a/tests/automap/same_typevar.fut +++ /dev/null @@ -1,16 +0,0 @@ --- == --- tags { no_wasm } --- entry: big_to_small --- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } - --- == --- entry: small_to_big --- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } - -def f 'a (x: a) (y: a) (z: a) = (x, y, z) - -entry big_to_small [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = - f xss ys z - -entry small_to_big [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = - f z ys xss diff --git a/tests/automap/sgemm.fut b/tests/automap/sgemm.fut deleted file mode 100644 index a31ce0188e..0000000000 --- a/tests/automap/sgemm.fut +++ /dev/null @@ -1,32 +0,0 @@ --- == --- entry: main --- random input { [5][10]f32 [10][3]f32 [5][3]f32 f32 f32 } --- output { true } - -def mult_orig [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = - let dotprod xs ys = f32.sum (map2 (*) xs ys) - in map (\xs -> map (dotprod xs) (transpose yss)) xss - -def add [n][m] (xss: [n][m]f32, yss: [n][m]f32): [n][m]f32 = - map2 (map2 (+)) xss yss - -def scale [n][m] (xss: [n][m]f32, a: f32): [n][m]f32 = - map (map1 (*a)) xss - -def main_orig [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) - (alpha: f32) (beta: f32) - : [n][p]f32 = - add(scale(css,beta), scale(mult_orig(ass,bss), alpha)) - - -def mult_am [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = - f32.sum ((transpose (replicate p xss)) * (replicate n (transpose yss))) - -def main_am [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) - (alpha: f32) (beta: f32) - : [n][p]f32 = - css*beta + mult_am(ass,bss)*alpha - -entry main [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) - (alpha: f32) (beta: f32) = - and (and (main_orig ass bss css alpha beta == main_am ass bss css alpha beta)) diff --git a/tests/automap/simple1.fut b/tests/automap/simple1.fut deleted file mode 100644 index f8833bb3b6..0000000000 --- a/tests/automap/simple1.fut +++ /dev/null @@ -1,7 +0,0 @@ --- == --- entry: main --- input { [1,2] 10 } --- output { [11, 12] } - -entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = - xs + y diff --git a/tests/automap/simple2.fut b/tests/automap/simple2.fut deleted file mode 100644 index ac57abcbe0..0000000000 --- a/tests/automap/simple2.fut +++ /dev/null @@ -1,8 +0,0 @@ --- == --- entry: main --- input { [[1,2],[3,4]] [1,1] } --- output { [[2,3],[4,5]] } - -entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = - xss + ys - diff --git a/tests/automap/simple3.fut b/tests/automap/simple3.fut deleted file mode 100644 index adc60bd43f..0000000000 --- a/tests/automap/simple3.fut +++ /dev/null @@ -1,8 +0,0 @@ --- == --- entry: main --- input { [[1,2],[3,4]] [1,1] } --- output { [[2,3],[4,5]] } - -entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = - ys + xss - diff --git a/tests/automap/simple4.fut b/tests/automap/simple4.fut deleted file mode 100644 index d94bbe4a6b..0000000000 --- a/tests/automap/simple4.fut +++ /dev/null @@ -1,8 +0,0 @@ --- == --- entry: main --- input { 3 [1,1] [[1,2],[3,4]] } --- output { [[5,6],[7,8]] } - -entry main [n] (x : i32) (ys: [n]i32) (zss : [n][n]i32) : [n][n]i32 = - x + ys + zss - diff --git a/tests/automap/simple5.fut b/tests/automap/simple5.fut deleted file mode 100644 index 46610e6567..0000000000 --- a/tests/automap/simple5.fut +++ /dev/null @@ -1,6 +0,0 @@ --- == --- input { [1,2,3] 4 } --- output { [5, 6, 7] } - -entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = - (\x y -> x + y) xs y diff --git a/tests/issue1599.fut b/tests/issue1599.fut new file mode 100644 index 0000000000..3ce47c38b1 --- /dev/null +++ b/tests/issue1599.fut @@ -0,0 +1,4 @@ +-- == +-- error: Occurs + +let bad a f = f a f diff --git a/tests/issue1926.fut b/tests/issue1926.fut index 6f79db9bb4..feaef47175 100644 --- a/tests/issue1926.fut +++ b/tests/issue1926.fut @@ -1,11 +1,12 @@ -- == --- error: cannot match value +-- error: cannot unify type with constructors type found = #found i32 | #not_found def main = let o = map (\x -> if (x > 3) then (#found x) else (#not_found)) [0, 1, 2, 3, 4] - let u = match o + let u = + match o case #found x -> x case #not_found -> -1 in u diff --git a/tests/types/inference5.fut b/tests/types/inference5.fut new file mode 100644 index 0000000000..d05b9084aa --- /dev/null +++ b/tests/types/inference5.fut @@ -0,0 +1,7 @@ +-- let should not be generalised +-- == +-- error: Cannot apply "apply" + +def main x = + let apply f x = f x + in apply (apply (i32.+) x) x From d0b890dfdc8a1333973f36218afb1b0e4a35de5a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 30 Jul 2025 14:53:01 +0200 Subject: [PATCH 293/296] Revert "Constraint-based type checking without AUTOMAP." This reverts commit 5ff3c67815e486b134528aeec6fceaa0887a0956. --- default.nix | 4 + futhark.cabal | 11 + nix/glpk-hs.nix | 23 + prelude/soacs.fut | 2 +- prelude/zip.fut | 18 +- shell.nix | 1 + .../Futhark/Solve/BranchAndBoundTests.hs | 143 +++++ src-testing/Futhark/Solve/SimplexTests.hs | 221 ++++++++ src-testing/futhark_tests.hs | 4 + src/Futhark/Solve/BranchAndBound.hs | 74 +++ src/Futhark/Solve/GLPK.hs | 60 +++ src/Futhark/Solve/LP.hs | 336 ++++++++++++ src/Futhark/Solve/Matrix.hs | 330 ++++++++++++ src/Futhark/Solve/Simplex.hs | 235 +++++++++ src/Language/Futhark/Interpreter.hs | 12 - src/Language/Futhark/Prop.hs | 10 - .../Futhark/TypeChecker/Constraints.hs | 10 - src/Language/Futhark/TypeChecker/Rank.hs | 493 +++++++++++++++++- src/Language/Futhark/TypeChecker/Terms.hs | 8 - .../Futhark/TypeChecker/Terms/Unsized.hs | 56 +- src/Language/Futhark/TypeChecker/TySolve.hs | 6 +- tests/automap/ambiguous0.fut | 4 + tests/automap/bool1.fut | 6 + tests/automap/combinations.fut | 38 ++ tests/automap/equality1.fut | 23 + tests/automap/lambda.fut | 6 + tests/automap/leetcode.fut | 4 + tests/automap/map0.fut | 8 + tests/automap/mri-q-qr.fut | 2 + tests/automap/mri-q.fut | 41 ++ tests/automap/operator1.fut | 9 + tests/automap/optionpricing.fut | 78 +++ tests/automap/pagerank.fut | 18 + tests/automap/project.fut | 9 + tests/automap/projsec1.fut | 9 + tests/automap/same_typevar.fut | 16 + tests/automap/sgemm.fut | 32 ++ tests/automap/simple1.fut | 7 + tests/automap/simple2.fut | 8 + tests/automap/simple3.fut | 8 + tests/automap/simple4.fut | 8 + tests/automap/simple5.fut | 6 + tests/issue1599.fut | 4 - tests/issue1926.fut | 5 +- tests/types/inference5.fut | 7 - 45 files changed, 2310 insertions(+), 103 deletions(-) create mode 100644 nix/glpk-hs.nix create mode 100644 src-testing/Futhark/Solve/BranchAndBoundTests.hs create mode 100644 src-testing/Futhark/Solve/SimplexTests.hs create mode 100644 src/Futhark/Solve/BranchAndBound.hs create mode 100644 src/Futhark/Solve/GLPK.hs create mode 100644 src/Futhark/Solve/LP.hs create mode 100644 src/Futhark/Solve/Matrix.hs create mode 100644 src/Futhark/Solve/Simplex.hs create mode 100644 tests/automap/ambiguous0.fut create mode 100644 tests/automap/bool1.fut create mode 100644 tests/automap/combinations.fut create mode 100644 tests/automap/equality1.fut create mode 100644 tests/automap/lambda.fut create mode 100644 tests/automap/leetcode.fut create mode 100644 tests/automap/map0.fut create mode 100644 tests/automap/mri-q-qr.fut create mode 100644 tests/automap/mri-q.fut create mode 100644 tests/automap/operator1.fut create mode 100644 tests/automap/optionpricing.fut create mode 100644 tests/automap/pagerank.fut create mode 100644 tests/automap/project.fut create mode 100644 tests/automap/projsec1.fut create mode 100644 tests/automap/same_typevar.fut create mode 100644 tests/automap/sgemm.fut create mode 100644 tests/automap/simple1.fut create mode 100644 tests/automap/simple2.fut create mode 100644 tests/automap/simple3.fut create mode 100644 tests/automap/simple4.fut create mode 100644 tests/automap/simple5.fut delete mode 100644 tests/issue1599.fut delete mode 100644 tests/types/inference5.fut diff --git a/default.nix b/default.nix index c43240b135..5821f6ca21 100644 --- a/default.nix +++ b/default.nix @@ -37,6 +37,9 @@ let gasp = haskellPackagesNew.callPackage ./nix/gasp.nix {}; + glpk-hs = + haskellPackagesNew.callPackage ./nix/glpk-hs.nix {}; + futhark = # callCabal2Nix does not do a great job at determining # which files must be included as source, which causes @@ -75,6 +78,7 @@ let "--extra-lib-dirs=${pkgs.glibc.static}/lib" "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" + "--extra-lib-dirs=${pkgs.glpk.overrideAttrs (old: { dontDisableStatic = true; })}/lib" # The ones below are due to GHC's runtime system # depending on libdw (DWARF info), which depends on # a bunch of compression algorithms. diff --git a/futhark.cabal b/futhark.cabal index 2ea03448c9..d72eb9a204 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -382,6 +382,11 @@ library Futhark.Pkg.Types Futhark.Profile Futhark.Script + Futhark.Solve.GLPK + Futhark.Solve.LP + Futhark.Solve.Matrix + Futhark.Solve.Simplex + Futhark.Solve.BranchAndBound Futhark.Test Futhark.Test.Spec Futhark.Test.Values @@ -506,6 +511,9 @@ library , mwc-random , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 + -- remove me later + , glpk-hs + , silently executable futhark import: common @@ -541,6 +549,8 @@ library futhark-testing Futhark.Optimise.ArrayLayoutTests Futhark.Pkg.SolveTests Futhark.ProfileTests + Futhark.Solve.BranchAndBoundTests + Futhark.Solve.SimplexTests Language.Futhark.CoreTests Language.Futhark.PrettyTests Language.Futhark.ParserBenchmarks @@ -565,6 +575,7 @@ library futhark-testing , tasty-hunit , tasty-quickcheck , text + , vector >=0.12 , srcloc , regex-tdfa ^>= 1.3.2 diff --git a/nix/glpk-hs.nix b/nix/glpk-hs.nix new file mode 100644 index 0000000000..189135ed22 --- /dev/null +++ b/nix/glpk-hs.nix @@ -0,0 +1,23 @@ +{ mkDerivation, array, base, containers, deepseq, fetchgit, gasp +, glpk, lib, mtl +}: +mkDerivation { + pname = "glpk-hs"; + version = "0.8"; + src = fetchgit { + url = "https://github.com/jyp/glpk-hs.git"; + sha256 = "sha256-AY9wmmqzafpocUspQAvHjDkT4vty5J3GcSOt5qItnlo="; + rev = "1f276aa19861203ea8367dc27a6ad4c8a31c9062"; + fetchSubmodules = true; + }; + isLibrary = true; + isExecutable = true; + libraryHaskellDepends = [ array base containers deepseq gasp mtl ]; + librarySystemDepends = [ glpk ]; + executableHaskellDepends = [ + array base containers deepseq gasp mtl + ]; + description = "Comprehensive GLPK linear programming bindings"; + license = lib.licenses.bsd3; + mainProgram = "glpk-hs-example"; +} diff --git a/prelude/soacs.fut b/prelude/soacs.fut index ea3cc90614..71ee9ed5bf 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -48,7 +48,7 @@ import "zip" -- -- **Span:** *O(S(f))* def map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = - intrinsics.map f as + f as -- | Apply the given function to each element of a single array. -- diff --git a/prelude/zip.fut b/prelude/zip.fut index cf57c71f09..5ccbacc17b 100644 --- a/prelude/zip.fut +++ b/prelude/zip.fut @@ -6,12 +6,6 @@ -- The main reason this module exists is that we need it to define -- SOACs like `map2`. --- We need a map to define some of the zip variants, but this file is --- depended upon by soacs.fut. So we just define a quick-and-dirty --- internal one here that uses the intrinsic version. -local -def internal_map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = - intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = @@ -23,15 +17,15 @@ def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = -- | As `zip2`@term, but with one more array. def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c) : *[n](a, b, c) = - internal_map (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) + (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) -- | As `zip3`@term, but with one more array. def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) : *[n](a, b, c, d) = - internal_map (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) + (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) -- | As `zip4`@term, but with one more array. def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e) : *[n](a, b, c, d, e) = - internal_map (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) + (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) -- | Turn an array of pairs into two arrays. def unzip [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = @@ -43,18 +37,18 @@ def unzip2 [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = -- | As `unzip2`@term, but with one more array. def unzip3 [n] 'a 'b 'c (xs: [n](a, b, c)) : ([n]a, [n]b, [n]c) = - let (as, bcs) = unzip (internal_map (\(a, b, c) -> (a, (b, c))) xs) + let (as, bcs) = unzip ((\(a, b, c) -> (a, (b, c))) xs) let (bs, cs) = unzip bcs in (as, bs, cs) -- | As `unzip3`@term, but with one more array. def unzip4 [n] 'a 'b 'c 'd (xs: [n](a, b, c, d)) : ([n]a, [n]b, [n]c, [n]d) = - let (as, bs, cds) = unzip3 (internal_map (\(a, b, c, d) -> (a, b, (c, d))) xs) + let (as, bs, cds) = unzip3 ((\(a, b, c, d) -> (a, b, (c, d))) xs) let (cs, ds) = unzip cds in (as, bs, cs, ds) -- | As `unzip4`@term, but with one more array. def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a, b, c, d, e)) : ([n]a, [n]b, [n]c, [n]d, [n]e) = - let (as, bs, cs, des) = unzip4 (internal_map (\(a, b, c, d, e) -> (a, b, c, (d, e))) xs) + let (as, bs, cs, des) = unzip4 ((\(a, b, c, d, e) -> (a, b, c, (d, e))) xs) let (ds, es) = unzip des in (as, bs, cs, ds, es) diff --git a/shell.nix b/shell.nix index 8f563ed276..72e1cac32e 100644 --- a/shell.nix +++ b/shell.nix @@ -43,6 +43,7 @@ pkgs.stdenv.mkDerivation { niv ispc imagemagick # needed for literate tests + glpk ] ++ lib.optionals (stdenv.isLinux) [ opencl-headers diff --git a/src-testing/Futhark/Solve/BranchAndBoundTests.hs b/src-testing/Futhark/Solve/BranchAndBoundTests.hs new file mode 100644 index 0000000000..b7e1bfe027 --- /dev/null +++ b/src-testing/Futhark/Solve/BranchAndBoundTests.hs @@ -0,0 +1,143 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Futhark.Solve.BranchAndBoundTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.BranchAndBound +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) + +tests :: TestTree +tests = + testGroup + "BranchAndBoundTests" + [ -- testCase "1" $ + -- let lpe = + -- LPE + -- { pc = V.fromList [1, 1, 0, 0, 0], + -- pA = + -- M.fromLists + -- [ [-1, 1, 1, 0, 0], + -- [1, 0, 0, 1, 0], + -- [0, 1, 0, 0, 1] + -- ], + -- pd = V.fromList [1, 3, 2] + -- } + -- in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in branchAndBound lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in branchAndBound lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ branchAndBound lp) $ + case branchAndBound lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (11.8 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 3]), + -- testCase "5" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> oneIsZero ("b1", "x1") ("b2", "x2") + -- } + -- (lp, _idxmap) = linearProgToLP prog + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, _sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ], + -- testCase "6" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in assertBool + -- (unlines [show $ branchAndBound lp]) + -- $ case branchAndBound lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (10 :: Double) + -- ] + + testCase "10" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "R2" ~+~ var "M3", + constraints = + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ branchAndBound lp]) + $ case branchAndBound lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (0 :: Double) + ] + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/src-testing/Futhark/Solve/SimplexTests.hs b/src-testing/Futhark/Solve/SimplexTests.hs new file mode 100644 index 0000000000..c29bd10a93 --- /dev/null +++ b/src-testing/Futhark/Solve/SimplexTests.hs @@ -0,0 +1,221 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Futhark.Solve.SimplexTests + ( tests, + ) +where + +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.LP +import Futhark.Solve.Matrix qualified as M +import Futhark.Solve.Simplex +import Test.Tasty +import Test.Tasty.HUnit +import Prelude hiding (or) + +tests :: TestTree +tests = + testGroup + "SimplexTests" + [ testCase "1" $ + let lpe = + LPE + { pc = V.fromList [1, 1, 0, 0, 0], + pA = + M.fromLists + [ [-1, 1, 1, 0, 0], + [1, 0, 0, 1, 0], + [0, 1, 0, 0, 1] + ], + pd = V.fromList [1, 3, 2] + } + in simplex lpe @?= Just (5 :: Double, V.fromList [3, 2, 2, 0, 0]), + testCase "2" $ + let lp = + LP + { lpc = V.fromList [40, 30], + lpA = + M.fromLists + [ [1, 1], + [2, 1] + ], + lpd = V.fromList [12, 16] + } + in simplexLP lp @?= Just (400 :: Double, V.fromList [4, 8]), + testCase "3" $ + let lp = + LP + { lpc = V.fromList [1, 2, 3], + lpA = + M.fromLists + [ [1, 1, 1], + [2, 1, 3] + ], + lpd = V.fromList [12, 18] + } + in simplexLP lp @?= Just (27 :: Double, V.fromList [0, 9, 3]), + testCase "4" $ + let lp = + LP + { lpc = V.fromList [5.5, 2.1], + lpA = + M.fromLists + [ [-1, 1], + [8, 2] + ], + lpd = V.fromList [2, 17] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (14.08 :: Double)) + && and (zipWith approxEq (V.toList sol) [1.3, 3.3]), + testCase "5" $ + let lp = + LP + { lpc = V.fromList [0], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [0, 0] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (0 :: Double)) + && and (zipWith approxEq (V.toList sol) [0]), + testCase "6" $ + let lp = + LP + { lpc = V.fromList [1], + lpA = + M.fromLists + [ [1], + [-1] + ], + lpd = V.fromList [5, 5] + } + in assertBool (show $ simplexLP lp) $ + case simplexLP lp of + Nothing -> False + Just (z, sol) -> + z `approxEq` (5 :: Double) + && and (zipWith approxEq (V.toList sol) [5]), + testCase "7" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1", + constraints = + [ var "x1" ~<=~ 10 ~*~ var "b1", + var "b1" ~+~ var "b2" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, sol) -> + (z `approxEq` (10 :: Double)) + && and (zipWith (==) (V.toList sol) [1, 0, 10]), + testCase "8" $ + let prog = + LinearProg + { optType = Maximize, + objective = var "x1" ~+~ var "x2", + constraints = + [ var "x1" ~<=~ constant 10, + var "x2" ~<=~ constant 5 + ] + <> oneIsZero ("b1", "x1") ("b2", "x2") + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (15 :: Double) + ], + -- testCase "9" $ + -- let prog = + -- LinearProg + -- { optType = Maximize, + -- objective = var "x1" ~+~ var "x2", + -- constraints = + -- [ var "x1" ~<=~ constant 10, + -- var "x2" ~<=~ constant 5 + -- ] + -- <> or "b1" "b2" (var "x1" ~==~ constant 0) (var "x2" ~==~ constant 0) + -- } + -- (lp, idxmap) = linearProgToLP prog + -- lpe = convert lp + -- in trace + -- (unlines [show prog, show lp, show idxmap, show lpe]) + -- ( assertBool + -- (unlines [show $ simplexLP lp]) + -- $ case simplexLP lp of + -- Nothing -> False + -- Just (z, sol) -> + -- and + -- [ z `approxEq` (15 :: Double) + -- ] + -- ), + testCase "10" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "R2" ~+~ var "M3", + constraints = + [ var "artifical4" ~==~ constant 1 ~+~ var "t0", + constant 1 ~+~ var "num1" ~==~ constant 1 ~+~ var "t0", + var "b_R2" ~<=~ constant 1, + var "b_M3" ~<=~ constant 1, + var "R2" ~<=~ 1000 ~*~ var "b_R2", + var "M3" ~<=~ 1000 ~*~ var "b_M3", + var "b_R2" ~+~ var "b_M3" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (0 :: Double) + ], + testCase "11" $ + let prog = + LinearProg + { optType = Minimize, + objective = var "4R" ~+~ var "5M", + constraints = + [ var "6artifical" ~==~ constant 1 ~+~ var "2t", + constant 1 ~+~ var "3num" ~==~ constant 1 ~+~ var "2t", + var "0b_R" ~<=~ constant 1, + var "1b_M" ~<=~ constant 1, + var "4R" ~<=~ 1000 ~*~ var "0b_R", + var "5M" ~<=~ 1000 ~*~ var "1b_M", + var "0b_R" ~+~ var "1b_M" ~<=~ constant 1 + ] + } + (lp, _idxmap) = linearProgToLP prog + in assertBool + (unlines [show $ simplexLP lp]) + $ case simplexLP lp of + Nothing -> False + Just (z, _sol) -> + and + [ z `approxEq` (0 :: Double) + ] + ] + +approxEq :: (Fractional a, Ord a) => a -> a -> Bool +approxEq x1 x2 = abs (x1 - x2) < 10 ^^ (-10 :: Int) diff --git a/src-testing/futhark_tests.hs b/src-testing/futhark_tests.hs index 18c85cf6d1..d39059dc03 100644 --- a/src-testing/futhark_tests.hs +++ b/src-testing/futhark_tests.hs @@ -11,6 +11,8 @@ import Futhark.Internalise.TypesValuesTests qualified import Futhark.Optimise.ArrayLayoutTests qualified import Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests qualified import Futhark.Pkg.SolveTests qualified +import Futhark.Solve.BranchAndBoundTests qualified +import Futhark.Solve.SimplexTests qualified import Language.Futhark.PrettyTests qualified import Language.Futhark.PrimitiveTests qualified import Language.Futhark.SemanticTests qualified @@ -37,6 +39,8 @@ allTests = Futhark.Analysis.AlgSimplifyTests.tests, Language.Futhark.TypeCheckerTests.tests, Language.Futhark.SemanticTests.tests, + Futhark.Solve.SimplexTests.tests, + Futhark.Solve.BranchAndBoundTests.tests, Futhark.Optimise.ArrayLayoutTests.tests ] diff --git a/src/Futhark/Solve/BranchAndBound.hs b/src/Futhark/Solve/BranchAndBound.hs new file mode 100644 index 0000000000..258757113b --- /dev/null +++ b/src/Futhark/Solve/BranchAndBound.hs @@ -0,0 +1,74 @@ +module Futhark.Solve.BranchAndBound (branchAndBound) where + +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.LP (LP (..)) +import Futhark.Solve.Matrix +import Futhark.Solve.Simplex + +newtype Bound a = Bound (Maybe a, Maybe a) + deriving (Eq, Ord, Show) + +instance (Ord a) => Semigroup (Bound a) where + Bound (mlb1, mub1) <> Bound (mlb2, mub2) = + Bound (combine max mlb1 mlb2, combine min mub1 mub2) + where + combine _ Nothing b2 = b2 + combine _ b1 Nothing = b1 + combine c (Just b1) (Just b2) = Just $ c b1 b2 + +-- | Solves an LP with the additional constraint that all solutions +-- must be integral. Returns 'Nothing' if infeasible or unbounded. +branchAndBound :: + (Read a, Unbox a, RealFrac a, Show a) => + LP a -> + Maybe (a, Vector Int) +branchAndBound prob@(LP _ a d) = (zopt,) <$> mopt + where + (zopt, mopt) = step (S.singleton mempty) (negate $ read "Infinity") Nothing + step todo zlow opt + | S.null todo = (zlow, opt) + | otherwise = + let (next, rest) = S.deleteFindMin todo + in case simplexLP (mkProblem next) of + Nothing -> step rest zlow opt + Just (z, sol) + | z <= zlow -> step rest zlow opt + | V.all isInt sol -> + step rest z (Just $ V.map round sol) + | otherwise -> + let (idx, frac) = + V.head $ V.filter (not . isInt . snd) $ V.zip (V.generate (V.length sol) id) sol + new_todo = + S.fromList $ + filter + (/= next) + [ M.insertWith (<>) idx (Bound (Nothing, Just $ fromInteger $ floor frac)) next, + M.insertWith (<>) idx (Bound (Just $ fromInteger $ ceiling frac, Nothing)) next + ] + in step (new_todo <> rest) zlow opt + + -- TODO: use isInt x = x == round x + -- requires a better 'rowEchelon' implementation for matrices + isInt x = abs (fromIntegral (round x :: Int) - x) <= 10 ^^ ((-10) :: Int) + mkProblem = + M.foldrWithKey + ( \idx bound acc -> addBound acc idx bound + ) + prob + + addBound lp idx (Bound (mlb, mub)) = + lp + { lpA = a `addRows` new_rows, + lpd = d V.++ V.fromList new_ds + } + where + (new_rows, new_ds) = + unzip $ + catMaybes + [ (V.generate (ncols a) (\i -> if i == idx then (-1) else 0),) <$> (negate <$> mlb), + (V.generate (ncols a) (\i -> if i == idx then 1 else 0),) <$> mub + ] diff --git a/src/Futhark/Solve/GLPK.hs b/src/Futhark/Solve/GLPK.hs new file mode 100644 index 0000000000..5c8f40fcd8 --- /dev/null +++ b/src/Futhark/Solve/GLPK.hs @@ -0,0 +1,60 @@ +module Futhark.Solve.GLPK (glpk) where + +import Control.Monad +import Data.Bifunctor +import Data.LinearProgram +import Data.Map qualified as M +import Data.Maybe +import Data.Set qualified as S +import Futhark.Solve.LP qualified as F +import System.IO.Silently + +linearProgToGLPK :: (Ord v, Num a) => F.LinearProg v a -> LP v a +linearProgToGLPK prog = + LP + { direction = cOptType $ F.optType prog, + objective = cObj $ F.objective prog, + constraints = map cConstraint $ F.constraints prog, + varBounds = bounds, + varTypes = kinds + } + where + cOptType F.Maximize = Max + cOptType F.Minimize = Min + cObj = fst . cLSum + + cLSum (F.LSum m) = + ( M.mapKeys fromJust $ M.filterWithKey (\k _ -> isJust k) m, + fromMaybe 0 (m M.!? Nothing) + ) + + cConstraint (F.Constraint ctype l r) = + let (linfunc, c) = cLSum $ l F.~-~ r + bound = + case ctype of + F.Equal -> Equ (-c) + F.LessEq -> UBound (-c) + in Constr Nothing linfunc bound + + bounds = M.fromList $ (,LBound 0) <$> varList + kinds = M.fromList $ (,IntVar) <$> varList + + varList = S.toList $ F.vars prog + +glpk :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) +glpk lp = do + (output, res) <- capture $ glpk' lp + pure $ do + guard $ "PROBLEM HAS NO INTEGER FEASIBLE SOLUTION" `notElem` lines output + res + +glpk' :: (Ord v, Real a) => F.LinearProg v a -> IO (Maybe (Int, M.Map v Int)) +glpk' lp + | F.isConstant (F.objective lp) -- FIXME + = + pure $ pure (0, M.fromList $ map (,0) $ S.toList $ F.vars lp) + | otherwise = do + (_, mres) <- glpSolveVars opts $ linearProgToGLPK lp + pure $ bimap truncate (fmap truncate) <$> mres + where + opts = mipDefaults {msgLev = MsgAll} diff --git a/src/Futhark/Solve/LP.hs b/src/Futhark/Solve/LP.hs new file mode 100644 index 0000000000..5011ece9fb --- /dev/null +++ b/src/Futhark/Solve/LP.hs @@ -0,0 +1,336 @@ +module Futhark.Solve.LP + ( LP (..), + LPE (..), + convert, + normalize, + var, + constant, + cval, + bin, + or, + min, + max, + oneIsZero, + (~+~), + (~-~), + (~*~), + (!), + neg, + linearProgToLP, + linearProgToLPE, + LSum (..), + LinearProg (..), + OptType (..), + Constraint (..), + Vars (..), + CType (..), + (~==~), + (~<=~), + (~>=~), + rowEchelonLPE, + isConstant, + ) +where + +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.Matrix (Matrix (..)) +import Futhark.Solve.Matrix qualified as Matrix +import Futhark.Util.Pretty +import Language.Futhark.Pretty +import Prelude hiding (max, min, or) + +-- | A linear program. 'LP c a d' represents the program +-- +-- > maximize c^T * a +-- > subject to a * x <= d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LP a = LP + { lpc :: Vector a, + lpA :: Matrix a, + lpd :: Vector a + } + deriving (Eq, Show) + +-- | Equational form of a linear program. 'LPE c a d' represents the +-- program +-- +-- > maximize c^T * a +-- > subject to a * x = d +-- > x >= 0 +-- +-- The matrix 'a' is assumed to have linearly-independent rows. +data LPE a = LPE + { pc :: Vector a, + pA :: Matrix a, + pd :: Vector a + } + deriving (Eq, Show) + +rowEchelonLPE :: (Unbox a, Fractional a, Ord a) => LPE a -> LPE a +rowEchelonLPE (LPE c a d) = + LPE c (Matrix.sliceCols (V.generate (ncols a) id) ad) (Matrix.getCol (ncols a) ad) + where + ad = + Matrix.filterRows + (V.any (Prelude./= 0)) + (Matrix.rowEchelon $ a Matrix.<|> Matrix.fromColVector d) + +-- | Converts an 'LP' into an equivalent 'LPE' by introducing slack +-- variables. +convert :: (Num a, Unbox a) => LP a -> LPE a +convert (LP c a d) = LPE c' a' d + where + a' = a Matrix.<|> Matrix.diagonal (V.replicate (Matrix.nrows a) 1) + c' = c V.++ V.replicate (Matrix.nrows a) 0 + +-- | Linear sum of variables. +newtype LSum v a = LSum {lsum :: Map (Maybe v) a} + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LSum v a) where + pretty (LSum m) = + concatWith (surround " + ") + $ map + ( \(k, a) -> + case k of + Nothing -> pretty a + Just k' -> (if a == 1 then mempty else pretty a <> "*") <> prettyName k' + ) + $ M.toList m + +isConstant :: (Ord v) => LSum v a -> Bool +isConstant (LSum m) = M.keysSet m `S.isSubsetOf` S.singleton Nothing + +instance Functor (LSum v) where + fmap f (LSum m) = LSum $ fmap f m + +class Vars a v where + vars :: a -> Set v + +instance (Ord v) => Vars (LSum v a) v where + vars = S.fromList . catMaybes . M.keys . lsum + +-- | Type of constraint +data CType = Equal | LessEq + deriving (Show, Eq) + +instance Pretty CType where + pretty Equal = "==" + pretty LessEq = "<=" + +-- | A constraint for a linear program. +data Constraint v a + = Constraint CType (LSum v a) (LSum v a) + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (Constraint v a) where + pretty (Constraint t l r) = + pretty l <+> pretty t <+> pretty r + +instance (Ord v) => Vars (Constraint v a) v where + vars (Constraint _ l r) = vars l <> vars r + +data OptType = Maximize | Minimize + deriving (Show, Eq) + +instance Pretty OptType where + pretty Maximize = "maximize" + pretty Minimize = "minimize" + +-- | A linear program. +data LinearProg v a = LinearProg + { optType :: OptType, + objective :: LSum v a, + constraints :: [Constraint v a] + } + deriving (Show, Eq) + +instance (IsName v, Pretty a, Eq a, Num a) => Pretty (LinearProg v a) where + pretty (LinearProg opt obj cs) = + vcat + [ pretty opt, + indent 2 $ pretty obj, + "subject to", + indent 2 $ vcat $ map pretty cs + ] + +instance (Ord v) => Vars (LinearProg v a) v where + vars lp = + vars (objective lp) + <> foldMap vars (constraints lp) + +bigM :: (Num a) => a +bigM = 2 ^ (10 :: Int) + +-- max{x, y} = z +max :: (Num a, Ord v) => v -> LSum v a -> LSum v a -> LSum v a -> [Constraint v a] +max b x y z = + [ z ~>=~ x, + z ~>=~ y, + z ~<=~ x ~+~ bigM ~*~ var b, + z ~<=~ y ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +-- min{x, y} = z +min :: (Num a, Ord v) => v -> v -> v -> v -> [Constraint v a] +min b x y z = + [ var z ~<=~ var x, + var z ~<=~ var y, + var z ~>=~ var x ~-~ bigM ~*~ (constant 1 ~-~ var b), + var z ~>=~ var y ~-~ bigM ~*~ var b + ] + +oneIsZero :: (Num a, Ord v) => (v, v) -> (v, v) -> [Constraint v a] +oneIsZero (b1, x1) (b2, x2) = + mkC b1 x1 + <> mkC b2 x2 + <> [(var b1 ~+~ var b2) ~<=~ constant 1] + where + mkC b x = + [ var x ~<=~ bigM ~*~ var b + ] + +or :: (Num a, Ord v) => v -> v -> Constraint v a -> Constraint v a -> [Constraint v a] +or b1 b2 c1 c2 = + mkC b1 c1 + <> mkC b2 c2 + <> [var b1 ~+~ var b2 ~<=~ constant 1] + where + mkC b (Constraint Equal l r) = + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b), + l ~>=~ r ~-~ bigM ~*~ (constant 1 ~-~ var b) + ] + mkC b (Constraint LessEq l r) = + [ l ~<=~ r ~+~ bigM ~*~ (constant 1 ~-~ var b) + ] + +bin :: (Num a) => v -> Constraint v a +bin v = Constraint LessEq (var v) (constant 1) + +(~==~) :: LSum v a -> LSum v a -> Constraint v a +l ~==~ r = Constraint Equal l r + +infix 4 ~==~ + +(~<=~) :: LSum v a -> LSum v a -> Constraint v a +l ~<=~ r = Constraint LessEq l r + +infix 4 ~<=~ + +(~>=~) :: (Num a) => LSum v a -> LSum v a -> Constraint v a +l ~>=~ r = Constraint LessEq (neg l) (neg r) + +infix 4 ~>=~ + +normalize :: (Eq a, Num a) => LSum v a -> LSum v a +normalize = LSum . M.filter (/= 0) . lsum + +var :: (Num a) => v -> LSum v a +var v = LSum $ M.singleton (Just v) 1 + +constant :: a -> LSum v a +constant = LSum . M.singleton Nothing + +cval :: (Num a, Ord v) => LSum v a -> a +cval = (! Nothing) + +(~+~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +(LSum x) ~+~ (LSum y) = LSum $ M.unionWith (+) x y + +infixl 6 ~+~ + +(~-~) :: (Ord v, Num a) => LSum v a -> LSum v a -> LSum v a +x ~-~ y = x ~+~ neg y + +infixl 6 ~-~ + +(~*~) :: (Num a) => a -> LSum v a -> LSum v a +a ~*~ s = fmap (a *) s + +infixl 7 ~*~ + +(!) :: (Num a, Ord v) => LSum v a -> Maybe v -> a +(LSum m) ! v = fromMaybe 0 (m M.!? v) + +neg :: (Num a) => LSum v a -> LSum v a +neg (LSum x) = LSum $ fmap negate x + +-- | Converts a linear program given with a list of constraints +-- into the standard form. +linearProgToLP :: + forall v a. + (Unbox a, Num a, Ord v) => + LinearProg v a -> + (LP a, Map Int v) +linearProgToLP (LinearProg otype obj cs) = + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LP c a d, idxMap) + where + cs' = foldMap (convertEqCType . splitConstraint) cs + idxMap = + M.fromList $ + zip [0 ..] $ + catMaybes $ + M.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) + + convertEqCType :: (CType, LSum v a, a) -> [(LSum v a, a)] + convertEqCType (Equal, s, a) = [(s, a), (neg s, negate a)] + convertEqCType (LessEq, s, a) = [(s, a)] + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s + +-- | Converts a linear program given with a list of constraints +-- into the equational form. Assumes no <= constraints. +linearProgToLPE :: + forall v a. + (Unbox a, Num a, Ord v) => + LinearProg v a -> + (LPE a, Map Int v) +linearProgToLPE (LinearProg otype obj cs) = + let c = mkRow $ convertObj otype obj + a = Matrix.fromVectors $ map (mkRow . fst) cs' + d = V.fromList $ map snd cs' + in (LPE c a d, idxMap) + where + cs' = map (checkOnlyEqType . splitConstraint) cs + idxMap = + M.fromList $ + zip [0 ..] $ + catMaybes $ + M.keys $ + mconcat $ + map (lsum . fst) cs' + mkRow s = V.generate (M.size idxMap) $ \i -> s ! Just (idxMap M.! i) + + splitConstraint :: Constraint v a -> (CType, LSum v a, a) + splitConstraint (Constraint ctype l r) = + let c = negate $ cval (l ~-~ r) + in (ctype, l ~-~ r ~-~ constant c, c) + + checkOnlyEqType :: (CType, LSum v a, a) -> (LSum v a, a) + checkOnlyEqType (Equal, s, a) = (s, a) + checkOnlyEqType (ctype, _, _) = error $ show ctype + + convertObj :: OptType -> LSum v a -> LSum v a + convertObj Maximize s = s + convertObj Minimize s = neg s diff --git a/src/Futhark/Solve/Matrix.hs b/src/Futhark/Solve/Matrix.hs new file mode 100644 index 0000000000..39ec16a39e --- /dev/null +++ b/src/Futhark/Solve/Matrix.hs @@ -0,0 +1,330 @@ +module Futhark.Solve.Matrix + ( Matrix (..), + toList, + toLists, + fromRowVector, + fromColVector, + fromVectors, + fromLists, + (@), + (!), + sliceCols, + getColM, + getCol, + setCol, + sliceRows, + getRowM, + getRow, + (<|>), + (<->), + addRow, + addRows, + imap, + generate, + identity, + diagonal, + (<.>), + (.*), + (*.), + (.+.), + (.-.), + rowEchelon, + filterRows, + deleteRow, + deleteCol, + ) +where + +import Data.List qualified as L +import Data.Map qualified as M +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V + +-- A matrix represented as a 1D 'Vector'. +data Matrix a = Matrix + { elems :: Vector a, + nrows :: Int, + ncols :: Int + } + deriving (Eq) + +instance (Show a, Unbox a) => Show (Matrix a) where + show = + unlines . map show . toLists + +toList :: (Unbox a) => Matrix a -> [Vector a] +toList m = + map (\r -> V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +toLists :: (Unbox a) => Matrix a -> [[a]] +toLists m = + map (\r -> V.toList $ V.slice (r * ncols m) (ncols m) (elems m)) [0 .. nrows m - 1] + +fromRowVector :: (Unbox a) => Vector a -> Matrix a +fromRowVector v = + Matrix + { elems = v, + nrows = 1, + ncols = V.length v + } + +fromColVector :: (Unbox a) => Vector a -> Matrix a +fromColVector v = + Matrix + { elems = v, + nrows = V.length v, + ncols = 1 + } + +empty :: (Unbox a) => Matrix a +empty = Matrix mempty 0 0 + +fromVectors :: (Unbox a) => [Vector a] -> Matrix a +fromVectors [] = empty +fromVectors vs = + Matrix + { elems = V.concat vs, + nrows = length vs, + ncols = V.length $ head vs + } + +fromLists :: (Unbox a) => [[a]] -> Matrix a +fromLists xss = + Matrix + { elems = V.concat $ map V.fromList xss, + nrows = length xss, + ncols = length $ head xss + } + +class SelectCols a where + select :: Vector Int -> a -> a + (@) :: a -> Vector Int -> a + (@) = flip select + +infix 9 @ + +instance (Unbox a) => SelectCols (Vector a) where + select s v = V.map (v V.!) s + +instance (Unbox a) => SelectCols (Matrix a) where + select = sliceCols + +(!) :: (Unbox a) => Matrix a -> (Int, Int) -> a +m ! (r, c) = elems m V.! (ncols m * r + c) + +sliceCols :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceCols cols m = + Matrix + { elems = + V.generate (nrows m * V.length cols) $ \i -> + let col = cols V.! (i `rem` V.length cols) + row = i `div` V.length cols + in m ! (row, col), + nrows = nrows m, + ncols = V.length cols + } + +getColM :: (Unbox a) => Int -> Matrix a -> Matrix a +getColM col = sliceCols $ V.singleton col + +getCol :: (Unbox a) => Int -> Matrix a -> Vector a +getCol col = elems . getColM col + +setCol :: (Unbox a) => Int -> Vector a -> Matrix a -> Matrix a +setCol c col m = + m + { elems = + V.update_ (elems m) indices col + } + where + indices = V.generate (nrows m) $ + \r -> r * ncols m + c + +sliceRows :: (Unbox a) => Vector Int -> Matrix a -> Matrix a +sliceRows rows m = + Matrix + { elems = + V.generate (ncols m * V.length rows) $ \i -> + let row = rows V.! (i `rem` V.length rows) + col = i `div` V.length rows + in m ! (row, col), + nrows = V.length rows, + ncols = ncols m + } + +getRowM :: (Unbox a) => Int -> Matrix a -> Matrix a +getRowM row = sliceRows $ V.singleton row + +getRow :: (Unbox a) => Int -> Matrix a -> Vector a +getRow row = elems . getRowM row + +(<|>) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <|> m2 = + generate f (nrows m1) (ncols m1 + ncols m2) + where + f r c + | c < ncols m1 = m1 ! (r, c) + | otherwise = m2 ! (r, c - ncols m1) + +(<->) :: (Unbox a) => Matrix a -> Matrix a -> Matrix a +m1 <-> m2 = + generate f (nrows m1 + nrows m2) (ncols m1) + where + f r c + | r < nrows m1 = m1 ! (r, c) + | otherwise = m2 ! (r - nrows m1, c) + +addRow :: (Unbox a) => Matrix a -> Vector a -> Matrix a +addRow m v = + m + { elems = elems m V.++ v, + nrows = nrows m + 1 + } + +addRows :: (Unbox a) => Matrix a -> [Vector a] -> Matrix a +addRows = foldl addRow + +imap :: (Unbox a) => (Int -> Int -> a -> a) -> Matrix a -> Matrix a +imap f m = + m + { elems = V.imap g $ elems m + } + where + g i = + let r = i `div` ncols m + c = i `rem` nrows m + in f r c + +generate :: (Unbox a) => (Int -> Int -> a) -> Int -> Int -> Matrix a +generate f rows cols = + Matrix + { elems = + V.generate (rows * cols) $ \i -> + let r = i `div` cols + c = i `rem` cols + in f r c, + nrows = rows, + ncols = cols + } + +identity :: (Unbox a, Num a) => Int -> Matrix a +identity n = generate (\r c -> if r == c then 1 else 0) n n + +diagonal :: (Unbox a, Num a) => Vector a -> Matrix a +diagonal d = generate (\r c -> if r == c then d V.! r else 0) (V.length d) (V.length d) + +(<.>) :: (Unbox a, Num a) => Vector a -> Vector a -> a +v1 <.> v2 = V.sum $ V.zipWith (*) v1 v2 + +infixl 7 <.> + +(*.) :: (Unbox a, Num a) => Matrix a -> Vector a -> Vector a +m *. v = + V.generate (nrows m) $ \r -> + getRow r m <.> v + +infixl 7 *. + +(.*) :: (Unbox a, Num a) => Vector a -> Matrix a -> Vector a +v .* m = + V.generate (ncols m) $ \c -> + v <.> getCol c m + +infixl 7 .* + +(.-.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.-.) = V.zipWith (-) + +infixl 6 .-. + +(.+.) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a +(.+.) = V.zipWith (+) + +infixl 6 .+. + +swapRows :: (Unbox a) => Int -> Int -> Matrix a -> Matrix a +swapRows r1 r2 m = + m + { elems = + elems m `V.update` new + } + where + start1 = ncols m * r1 + start2 = ncols m * r2 + row1 = getRow r1 m + row2 = getRow r2 m + new = + V.imap (\i a -> (i + start1, a)) row2 + V.++ V.imap (\i a -> (i + start2, a)) row1 + +-- todo: fix +update :: (Unbox a) => Matrix a -> Vector ((Int, Int), a) -> Matrix a +update m upds = + generate + ( \i j -> + case M.fromList (V.toList upds) M.!? (i, j) of + Nothing -> m ! (i, j) + Just x -> x + ) + (nrows m) + (ncols m) + +-- This version doesn't maintain integrality of the entries. +rowEchelon :: (Fractional a, Unbox a, Ord a) => Matrix a -> Matrix a +rowEchelon = rowEchelon' 0 0 + where + rowEchelon' h k m@(Matrix _ nr nc) + | h < nr && k < nc = + if m ! (pivot_row, k) == 0 + then rowEchelon' h (k + 1) m + else rowEchelon' (h + 1) (k + 1) clear_rows_below + | otherwise = m + where + pivot_row = + fst $ + L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ + [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] + m' = swapRows h pivot_row m + clear_rows_below = + update m' $ + V.fromList $ + [((i, k), 0) | i <- [h + 1 .. nr - 1]] + ++ [ ((i, j), m' ! (i, j) - (m' ! (h, j)) * f) + | i <- [h + 1 .. nr - 1], + let f = m' ! (i, k) / m' ! (h, k), + j <- [k + 1 .. nc - 1] + ] + +-- TODO: fix. Something's wrong here, causes huge blow-up. +-- rowEchelon :: (Num a, Unbox a, Ord a) => Matrix a -> Matrix a +-- rowEchelon = rowEchelon' 0 0 +-- where +-- rowEchelon' h k m@(Matrix _ nr nc) +-- | h < nr && k < nc = +-- if m ! (pivot_row, k) == 0 +-- then rowEchelon' h (k + 1) m +-- else rowEchelon' (h + 1) (k + 1) clear_rows_below +-- | otherwise = m +-- where +-- pivot_row = +-- fst $ +-- L.maximumBy (\(_, x) (_, y) -> x `compare` y) $ +-- [(r, abs (m ! (r, k))) | r <- [h .. nr - 1]] +-- m' = swapRows h pivot_row m +-- clear_rows_below = +-- update m' $ +-- V.fromList $ +-- [((i, k), 0) | i <- [h + 1 .. nr - 1]] +-- ++ [ ((i, j), (m' ! (h, k)) * (m' ! (i, j)) - (m' ! (h, j)) * (m' ! (i, k))) +-- | i <- [h + 1 .. nr - 1], +-- j <- [k + 1 .. nc - 1] +-- ] + +filterRows :: (Unbox a) => (Vector a -> Bool) -> Matrix a -> Matrix a +filterRows p = fromVectors . filter p . toList + +deleteRow :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteRow n m = sliceRows (V.generate (nrows m - 1) (\r -> if r < n then r else r + 1)) m + +deleteCol :: (Unbox a) => Int -> Matrix a -> Matrix a +deleteCol n m = sliceCols (V.generate (ncols m - 1) (\c -> if c < n then c else c + 1)) m diff --git a/src/Futhark/Solve/Simplex.hs b/src/Futhark/Solve/Simplex.hs new file mode 100644 index 0000000000..362b300038 --- /dev/null +++ b/src/Futhark/Solve/Simplex.hs @@ -0,0 +1,235 @@ +module Futhark.Solve.Simplex + ( simplex, + simplexLP, + simplexProg, + findBasis, + ) +where + +import Data.List qualified as L +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as M +import Data.Maybe +import Data.Vector.Unboxed (Unbox, Vector) +import Data.Vector.Unboxed qualified as V +import Futhark.Solve.LP (LP (..), LPE (..), LinearProg (..), convert, linearProgToLPE, rowEchelonLPE) +import Futhark.Solve.Matrix + +-- | A tableau of an equational linear program @a * x = d@ is +-- +-- > x @ b = p + q * x @ n +-- > --------------------- +-- > z = z' + r^T * x @ n +-- +-- where @z = c^T * x@ and @b@ (@n@) is a vector containing the +-- indices of basic (nonbasic) variables. +-- +-- The basic feasible solution corresponding to the above tableau is +-- given by @x \@ b = p@, @x \@n = 0@ with the value of the objective +-- equal to @z'@. + +-- | Computes @r@ as given in the tableau above. +compR :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + Vector Int -> + Vector a +compR (LPE c a _) invA_B b n = + c @ n .-. c @ b .* invA_B .* a @ n + +-- | @compQEnter prob invA_B b n enter@ computes the @enter@th +-- column of @q@. +compQEnter :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Int -> + Vector a +compQEnter (LPE _ a _) invA_B enter = + V.map negate $ invA_B *. getCol enter a + +-- | Computes the objective given an inversion of @a@ and a basis. +compZ :: + (Num a, Unbox a) => + LPE a -> + Matrix a -> + Vector Int -> + a +compZ (LPE c _ d) invA_B b = + c @ b .* invA_B <.> d + +-- | Constructs an auxiliary equational linear program to compute the +-- initial feasible basis; returns the program along with a feasible +-- basis. +mkAux :: (Ord a, Unbox a, Num a) => LPE a -> (LPE a, Vector Int, Vector Int) +mkAux (LPE _ a d) = (LPE c_aux a_aux d_aux, b_aux, n_aux) + where + c_aux = V.replicate (ncols a) 0 V.++ V.replicate (nrows a) (-1) + d_aux = V.map abs d + a_aux = + imap (\r _ e -> if (d V.! r) < 0 then negate e else e) a + <|> identity (nrows a) + b_aux = V.generate (nrows a) (+ ncols a) + n_aux = V.generate (ncols a) id + +fixDegenerateBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Int -> + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +fixDegenerateBasis og_prob col prob (invA_B, p, b, n) + | Just exit_idx <- mexit_idx, + V.null (elim_row exit_idx) = + let prob' = + prob + { pA = deleteRow exit_idx (pA prob), + pd = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) $ + pd prob + } + invA_B' = deleteRow exit_idx $ deleteCol exit_idx invA_B + p' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) p + b' = + V.map snd $ + V.filter ((/= exit_idx) . fst) $ + V.imap (curry id) b + in fixDegenerateBasis og_prob col prob' (invA_B', p', b', n) + | Just exit_idx <- mexit_idx, + (enter, _) <- V.head (elim_row exit_idx) = + let enter_idx = fromJust $ V.findIndex (== enter) n + exit = b V.! exit_idx + in fixDegenerateBasis og_prob col prob $ + pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = + let prob' = + prob + { pc = pc og_prob, + pA = sliceCols (V.generate col id) $ pA prob, + pd = V.map abs $ pd og_prob + } + in (prob', invA_B, p, V.filter (< col) b, V.filter (< col) n) + where + mexit_idx = + fst <$> V.filter ((>= col) . snd) (V.imap (curry id) b) V.!? 0 + elim_row exit_idx = + V.filter ((/= 0) . snd) $ + V.map (\j -> (j, compQEnter prob invA_B j V.! exit_idx)) $ + V.generate col id + +-- | Finds an initial feasible basis for an equational linear program. +-- Returns 'Nothing' if the LP has no solution. Inverts some +-- equations by multiplying by -1 so it also returns a modified (but +-- equivalent) equational linear program. +findBasis :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (LPE a, Matrix a, Vector a, Vector Int, Vector Int) +findBasis prob = do + (invA_B, p, b, n) <- step p_aux (invA_B_aux, d_aux, b_aux, n_aux) + if compZ p_aux invA_B b == 0 + then Just $ fixDegenerateBasis prob (ncols $ pA prob) p_aux (invA_B, p, b, n) + else Nothing + where + (p_aux@(LPE _ _ d_aux), b_aux, n_aux) = mkAux prob + invA_B_aux = identity $ V.length b_aux + +-- | Solves an equational linear program. Returns 'Nothing' if the +-- program is infeasible or unbounded. Otherwise returns the optimal +-- value and the solution. +simplex :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + Maybe (a, Vector a) +simplex lpe = do + (lpe', invA_B, p, b, n) <- findBasis $ rowEchelonLPE lpe + (invA_B', p', b', n') <- step lpe' (invA_B, p, b, n) + let z = compZ lpe' invA_B' b' + sol = + V.map snd $ + V.fromList $ + L.sortOn fst $ + V.toList $ + V.zip (b' V.++ n') (p' V.++ V.replicate (V.length n') 0) + pure (z, sol) + +-- | Solves a linear program. +simplexLP :: + (Unbox a, Ord a, Fractional a, Show a) => + LP a -> + Maybe (a, Vector a) +simplexLP lp = do + (opt, sol) <- simplex lpe + pure (opt, V.take (ncols $ lpA lp) sol) + where + lpe = convert lp + +simplexProg :: + (Unbox a, Ord a, Ord v, Fractional a, Show a) => + LinearProg v a -> + Maybe (a, Map v a) +simplexProg prog = do + (z, sol) <- simplex lpe + pure (z, M.fromList $ zipWith (\i x -> (idxMap M.! i, x)) [0 ..] $ V.toList sol) + where + (lpe, idxMap) = linearProgToLPE prog + +pivot :: + (Unbox a, Fractional a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + (Int, Int) -> + (Int, Int) -> + (Matrix a, Vector a, Vector Int, Vector Int) +pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) = + (invA_B', p', b', n') + where + q_enter = compQEnter prob invA_B enter + b' = b V.// [(exit_idx, enter)] + n' = n V.// [(enter_idx, exit)] + e_inv_vec = + V.map + (/ abs (q_enter V.! exit_idx)) + (q_enter V.// [(exit_idx, 1)]) + genF row col = + (if row == exit_idx then 0 else invA_B ! (row, col)) + + (e_inv_vec V.! row) * invA_B ! (exit_idx, col) + invA_B' = generate genF (nrows invA_B) (ncols invA_B) + p' = p V.// [(exit_idx, 0)] .+. V.map (* (p V.! exit_idx)) e_inv_vec + +-- | One step of the simplex algorithm. +step :: + (Unbox a, Ord a, Fractional a, Show a) => + LPE a -> + (Matrix a, Vector a, Vector Int, Vector Int) -> + Maybe (Matrix a, Vector a, Vector Int, Vector Int) +step prob (invA_B, p, b, n) + | Just enter_idx <- menter_idx = + let enter = n V.! enter_idx + q_enter = compQEnter prob invA_B enter + pq = + V.map (\(i, p_', q_) -> (i, -(p_' / q_))) $ + V.filter (\(_, _, q_) -> q_ < 0) $ + V.zip3 (V.generate (V.length q_enter) id) p q_enter + in if V.null pq + then Nothing + else + let exit_val = snd $ V.minimumOn snd pq + exit_cands = + V.map fst $ V.filter ((exit_val ==) . snd) pq + (exit_idx, exit) = + V.minimumOn snd $ + V.map (\i -> (i, b V.! i)) exit_cands + in step prob $ pivot prob (invA_B, p, b, n) (enter_idx, enter) (exit_idx, exit) + | otherwise = Just (invA_B, p, b, n) + where + r = compR prob invA_B b n + menter_idx = V.findIndex (> 0) r diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index e683f66a7d..34a0cfa95b 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -1698,18 +1698,6 @@ initialCtx = pure $ ValuePrim $ UnsignedValue x' ValueAD {} -> pure x -- FIXME: these do not carry signs. _ -> error $ "Cannot unsign: " <> show x - def "map" = Just $ - TermPoly Nothing $ \t -> do - t' <- evalTypeFully t - pure $ ValueFun $ \f -> pure . ValueFun $ \xs -> - case unfoldFunType t' of - ([_, _], ret_t) - | rowshape <- typeShape $ stripArray 1 ret_t -> - toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) - _ -> - error $ - "Invalid arguments to map intrinsic:\n" - ++ unlines [prettyString t, show f, show xs] def s | "reduce" `T.isPrefixOf` s = Just $ fun3 $ \f ne xs -> foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index f080fc5013..f95db4deac 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -864,16 +864,6 @@ intrinsics = $ array_a Unique $ shape [m, k, l] ), - ( "map", - IntrinsicPolyFun - [tp_a, tp_b, sp_n] - [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), - array_a Observe $ shape [n] - ] - $ RetType [] - $ array_b Unique - $ shape [n] - ), ( "reduce", IntrinsicPolyFun [tp_a, sp_n] diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 3eb0674097..9a75349517 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -14,7 +14,6 @@ module Language.Futhark.TypeChecker.Constraints ) where -import Data.Bifunctor import Data.Loc import Data.Map qualified as M import Futhark.Util.Pretty @@ -73,9 +72,6 @@ instance Located (Reason t) where data CtTy d = CtEq (Reason (CtType d)) (TypeBase d NoUniqueness) (TypeBase d NoUniqueness) deriving (Show) -instance Functor CtTy where - fmap f (CtEq r x y) = CtEq (fmap (first f) r) (first f x) (first f y) - ctReason :: CtTy d -> Reason (CtType d) ctReason (CtEq r _ _) = r @@ -105,12 +101,6 @@ data TyVarInfo d TyVarSum Loc (M.Map Name [CtType d]) deriving (Show, Eq) -instance Functor TyVarInfo where - fmap _ (TyVarFree loc l) = TyVarFree loc l - fmap _ (TyVarPrim loc ts) = TyVarPrim loc ts - fmap f (TyVarRecord loc m) = TyVarRecord loc $ M.map (first f) m - fmap f (TyVarSum loc m) = TyVarSum loc $ M.map (map (first f)) m - prettyTyVarInfo :: (Pretty (Shape d)) => TyVarInfo d -> Doc a prettyTyVarInfo (TyVarFree _ l) = "free" <+> pretty l prettyTyVarInfo (TyVarPrim _ pts) = "∈" <+> pretty pts diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 8b2f6c0be3..b82de905f5 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -4,15 +4,247 @@ module Language.Futhark.TypeChecker.Rank ) where -import Control.Monad (void) +import Control.Monad +import Control.Monad.Reader +import Control.Monad.State import Data.Bifunctor +import Data.Functor.Identity +import Data.List qualified as L +import Data.Map (Map) import Data.Map qualified as M +import Data.Maybe import Futhark.IR.Pretty () +import Futhark.Solve.GLPK +import Futhark.Solve.LP hiding (Constraint, LSum, LinearProg) +import Futhark.Solve.LP qualified as LP +import Futhark.Util (debugTraceM) +import Futhark.Util.Pretty import Language.Futhark hiding (ScalarType) +import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Constraints +import Language.Futhark.TypeChecker.Monad +import System.IO.Unsafe + +type LSum = LP.LSum VName Int + +type Constraint = LP.Constraint VName Int + +type LinearProg = LP.LinearProg VName Int + +class Rank a where + rank :: a -> LSum + +instance Rank VName where + rank = var + +instance Rank SComp where + rank SDim = constant 1 + rank (SVar v) = var v + +instance Rank (Shape SComp) where + rank = foldr (\d r -> rank d ~+~ r) (constant 0) . shapeDims + +instance Rank (ScalarTypeBase SComp u) where + rank Prim {} = constant 0 + rank (TypeVar _ (QualName [] v) []) = var v + rank (TypeVar {}) = constant 0 + rank (Arrow {}) = constant 0 + rank (Record {}) = constant 0 + rank (Sum {}) = constant 0 + +instance Rank (TypeBase SComp u) where + rank (Scalar t) = rank t + rank (Array _ shape t) = rank shape ~+~ rank t + +distribAndSplitArrows :: CtTy d -> [CtTy d] +distribAndSplitArrows (CtEq r t1 t2) = + splitArrows $ CtEq r (distribute t1) (distribute t2) + where + distribute :: TypeBase dim as -> TypeBase dim as + distribute (Array u s (Arrow _ _ _ ta (RetType rd tr))) = + Scalar $ + Arrow + u + Unnamed + mempty + (arrayOf s ta) + (RetType rd $ distribute $ arrayOfWithAliases Nonunique s tr) + distribute t = t + + splitArrows + ( CtEq + reason + (Scalar (Arrow _ _ _ t1a (RetType _ t1r))) + (Scalar (Arrow _ _ _ t2a (RetType _ t2r))) + ) = + splitArrows (CtEq reason t1a t2a) ++ splitArrows (CtEq reason t1r' t2r') + where + t1r' = t1r `setUniqueness` NoUniqueness + t2r' = t2r `setUniqueness` NoUniqueness + splitArrows c = [c] + +distribAndSplitCnstrs :: CtTy d -> [CtTy d] +distribAndSplitCnstrs ct@(CtEq r t1 t2) = + ct : splitCnstrs (CtEq r (distribute1 t1) (distribute1 t2)) + where + distribute1 :: TypeBase dim as -> TypeBase dim as + distribute1 (Array u s (Record ts1)) = + Scalar $ Record $ fmap (arrayOfWithAliases u s) ts1 + distribute1 (Array u s (Sum cs)) = + Scalar $ Sum $ (fmap . fmap) (arrayOfWithAliases u s) cs + distribute1 t = t + + -- FIXME. Should check for key set equality here. + splitCnstrs (CtEq reason (Scalar (Record ts1)) (Scalar (Record ts2))) = + concat $ zipWith (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems ts1) (M.elems ts2) + splitCnstrs (CtEq reason (Scalar (Sum cs1)) (Scalar (Sum cs2))) = + concat $ concat $ (zipWith . zipWith) (\x y -> distribAndSplitCnstrs $ CtEq reason x y) (M.elems cs1) (M.elems cs2) + splitCnstrs _ = [] + +data RankState = RankState + { rankBinVars :: Map VName VName, + rankCounter :: !Int, + rankConstraints :: [Constraint], + rankObj :: LSum + } + +newtype RankM a = RankM {runRankM :: State RankState a} + deriving (Functor, Applicative, Monad, MonadState RankState) + +incCounter :: RankM Int +incCounter = do + s <- get + put s {rankCounter = rankCounter s + 1} + pure $ rankCounter s + +binVar :: VName -> RankM VName +binVar sv = do + mbv <- gets ((M.!? sv) . rankBinVars) + case mbv of + Nothing -> do + bv <- VName ("b_" <> baseName sv) <$> incCounter + modify $ \s -> + s + { rankBinVars = M.insert sv bv $ rankBinVars s, + rankConstraints = [bin bv, var bv ~<=~ var sv] <> rankConstraints s + } + pure bv + Just bv -> pure bv + +addConstraints :: [Constraint] -> RankM () +addConstraints cs = + modify $ \s -> s {rankConstraints = cs <> rankConstraints s} + +addConstraint :: Constraint -> RankM () +addConstraint = addConstraints . pure + +addObj :: SVar -> RankM () +addObj sv = + modify $ \s -> s {rankObj = rankObj s ~+~ var sv} + +addCt :: CtTy SComp -> RankM () +addCt (CtEq _ t1 t2) = addConstraint $ rank t1 ~==~ rank t2 + +addCtAM :: CtAM -> RankM () +addCtAM (CtAM _ r m f) = do + b_r <- binVar r + b_m <- binVar m + b_max <- VName "c_max" <$> incCounter + tr <- VName ("T_" <> baseName r) <$> incCounter + addConstraints [bin b_max, var b_max ~<=~ var tr] + addConstraints $ oneIsZero (b_r, r) (b_m, m) + addConstraints $ LP.max b_max (constant 0) (rank r ~-~ rank f) (var tr) + addObj m + addObj tr + +addTyVarInfo :: TyVar -> (Int, TyVarInfo d) -> RankM () +addTyVarInfo _ (_, TyVarFree {}) = pure () +addTyVarInfo tv (_, TyVarPrim {}) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarRecord {}) = + addConstraint $ rank tv ~==~ constant 0 +addTyVarInfo tv (_, TyVarSum {}) = + addConstraint $ rank tv ~==~ constant 0 + +mkLinearProg :: [CtTy SComp] -> [CtAM] -> TyVars d -> LinearProg +mkLinearProg cs cs_am tyVars = + LP.LinearProg + { optType = Minimize, + objective = rankObj finalState, + constraints = rankConstraints finalState + } + where + initState = + RankState + { rankBinVars = mempty, + rankCounter = 0, + rankConstraints = mempty, + rankObj = constant 0 + } + buildLP = do + mapM_ addCt cs + mapM_ addCtAM cs_am + mapM_ (uncurry addTyVarInfo) $ M.toList tyVars + finalState = flip execState initState $ runRankM buildLP + +ambigCheckLinearProg :: LinearProg -> (Int, Map VName Int) -> LinearProg +ambigCheckLinearProg prog (opt, ranks) = + prog + { constraints = + -- https://yetanothermathprogrammingconsultant.blogspot.com/2011/10/integer-cuts.html + [ lsum (var <$> M.keys one_bins) + ~-~ lsum (var <$> M.keys zero_bins) + ~<=~ constant (fromIntegral $ length one_bins) + ~-~ constant 1, + objective prog ~==~ constant (fromIntegral opt) + ] + ++ constraints prog + } + where + -- We really need to track which variables are binary in the LinearProg + is_bin_var = ("b_" `L.isPrefixOf`) . baseString + one_bins = M.filterWithKey (\k v -> is_bin_var k && v == 1) ranks + zero_bins = M.filterWithKey (\k v -> is_bin_var k && v == 0) ranks + lsum = foldr (~+~) (constant 0) + +enumerateRankSols :: LinearProg -> [Map VName Int] +enumerateRankSols prog = + take 5 $ + takeSolns $ + iterate next_sol $ + (prog,) <$> run_glpk prog + where + run_glpk = unsafePerformIO . glpk + next_sol m = do + (prog', sol') <- m + guard (fst sol' /= 0) + let prog'' = ambigCheckLinearProg prog' sol' + sol'' <- run_glpk prog'' + pure (prog'', sol'') + takeSolns [] = [] + takeSolns (Nothing : _) = [] + takeSolns (Just (_, (_, r)) : xs) = r : takeSolns xs + +solveRankILP :: (MonadTypeChecker m) => SrcLoc -> LinearProg -> m [Map VName Int] +solveRankILP loc prog = do + debugTraceM 3 $ + unlines + [ "## solveRankILP", + prettyString prog + ] + case enumerateRankSols prog of + [] -> typeError loc mempty "Rank ILP cannot be solved." + rs -> do + debugTraceM 3 "## rank maps" + forM_ (zip [0 :: Int ..] rs) $ \(i, r) -> + debugTraceM 3 $ + unlines $ + "\n## rank map " <> prettyString i + : map prettyString (M.toList r) + pure rs rankAnalysis1 :: - (Monad m) => + (MonadTypeChecker m) => SrcLoc -> ([CtTy SComp], [CtAM]) -> TyVars SComp -> @@ -26,19 +258,21 @@ rankAnalysis1 :: Exp, Maybe (TypeExp Exp VName) ) -rankAnalysis1 _loc (cs, _cs_am) tyVars artificial params body retdecl = - pure - ( ( map void cs, - M.map (first (const ())) artificial, - fmap (second void) tyVars - ), - params, - body, - retdecl - ) +rankAnalysis1 loc (cs, cs_am) tyVars artificial params body retdecl = do + solutions <- rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl + case solutions of + [sol] -> pure sol + sols -> do + let (_, _, bodies', _) = L.unzip4 sols + typeError loc mempty $ + stack $ + [ "Rank ILP is ambiguous.", + "Choices:" + ] + ++ map pretty bodies' rankAnalysis :: - (Monad m) => + (MonadTypeChecker m) => SrcLoc -> ([CtTy SComp], [CtAM]) -> TyVars SComp -> @@ -53,14 +287,225 @@ rankAnalysis :: Maybe (TypeExp Exp VName) ) ] -rankAnalysis _loc (cs, _cs_am) tyVars artificial params body retdecl = do - pure - [ ( ( map void cs, - M.map (first (const ())) artificial, - fmap (second void) tyVars - ), - params, - body, - retdecl - ) - ] +rankAnalysis _ ([], []) tyVars artificial params body retdecl = do + (_, artificial', tyVars') <- substRankInfo ([], []) artificial tyVars mempty + pure [(([], artificial', tyVars'), params, body, retdecl)] +rankAnalysis loc (cs, cs_am) tyVars artificial params body retdecl = do + debugTraceM 3 $ + unlines + [ "##rankAnalysis", + "cs:", + unlines $ map prettyString cs, + "cs':", + unlines $ map prettyString cs' + ] + rank_maps <- solveRankILP loc (mkLinearProg cs' cs_am tyVars) + cts_tyvars' <- mapM (substRankInfo (cs, cs_am) artificial tyVars) rank_maps + let bodys = map (`updAM` body) rank_maps + params' = map ((`map` params) . updAMPat) rank_maps + retdecls = map ((<$> retdecl) . updAMTypeExp) rank_maps + pure $ L.zip4 cts_tyvars' params' bodys retdecls + where + cs' = + foldMap distribAndSplitCnstrs $ + foldMap distribAndSplitArrows cs + +type RankMap = M.Map VName Int + +substRankInfo :: + (MonadTypeChecker m) => + ([CtTy SComp], [CtAM]) -> + M.Map VName (CtType SComp) -> + TyVars SComp -> + RankMap -> + m ([CtTy ()], M.Map VName (CtType ()), TyVars ()) +substRankInfo (cs, _cs_am) artificial tyVars rankmap = do + ((cs', artificial', tyVars'), new_cs, new_tyVars) <- + runSubstT tyVars rankmap $ + (,,) + <$> traverse substRanksCt cs + <*> traverse substRanksType artificial + <*> substRanksTyVars tyVars + pure (cs' <> new_cs, artificial', new_tyVars <> tyVars') + +runSubstT :: + (MonadTypeChecker m) => + TyVars SComp -> + RankMap -> + SubstT m a -> + m (a, [CtTy ()], TyVars ()) +runSubstT tyVars rankmap (SubstT m) = do + let env = + SubstEnv + { envTyVars = tyVars, + envRanks = rankmap + } + + s = + SubstState + { substTyVars = mempty, + substNewVars = mempty, + substNewCts = mempty + } + (a, s') <- runReaderT (runStateT m s) env + pure (a, substNewCts s', substTyVars s') + +newtype SubstT m a = SubstT (StateT SubstState (ReaderT SubstEnv m) a) + deriving + ( Functor, + Applicative, + Monad, + MonadState SubstState, + MonadReader SubstEnv + ) + +data SubstEnv = SubstEnv + { envTyVars :: TyVars SComp, + envRanks :: RankMap + } + +data SubstState = SubstState + { substTyVars :: TyVars (), + substNewVars :: Map TyVar TyVar, + substNewCts :: [CtTy ()] + } + +instance MonadTrans SubstT where + lift = SubstT . lift . lift + +rankToShape :: (Monad m) => VName -> SubstT m (Shape ()) +rankToShape x = do + rs <- asks envRanks + pure $ Shape $ replicate (fromJust $ rs M.!? x) () + +newTyVar :: (MonadTypeChecker m) => TyVar -> SubstT m TyVar +newTyVar t = do + t' <- lift $ newTypeName (baseName t) + shape <- rankToShape t + loc <- asks ((locOf . snd . fromJust . (M.!? t)) . envTyVars) + modify $ \s -> + s + { substNewVars = M.insert t t' $ substNewVars s, + substNewCts = + CtEq + (Reason loc) + (Scalar (TypeVar mempty (QualName [] t) [])) + (arrayOf shape (Scalar (TypeVar mempty (QualName [] t') []))) + : substNewCts s + } + pure t' + +addRankInfo :: (MonadTypeChecker m) => TyVar -> SubstT m () +addRankInfo t = do + rs <- asks envRanks + if fromMaybe 0 (rs M.!? t) == 0 + then pure () + else do + new_vars <- gets substNewVars + maybe new_var (const $ pure ()) $ new_vars M.!? t + where + new_var = do + t' <- newTyVar t + old_tyvars <- asks envTyVars + let (level, tvinfo) = fromJust $ old_tyvars M.!? t + l = case tvinfo of + TyVarFree _ tvinfo_l -> tvinfo_l + _ -> Unlifted + tvinfo' <- substRanksTyVarInfo tvinfo + modify $ \s -> s {substTyVars = M.insert t' (level, tvinfo') $ substTyVars s} + modify $ \s -> s {substTyVars = M.insert t (level, TyVarFree (locOf tvinfo) l) $ substTyVars s} + +substRanksShape :: (Monad m) => Shape SComp -> SubstT m (Shape ()) +substRanksShape = foldM (\s d -> (s <>) <$> instDim d) mempty + where + instDim SDim = pure $ Shape [()] + instDim (SVar x) = rankToShape x + +substRanksType :: (MonadTypeChecker m) => TypeBase SComp u -> SubstT m (TypeBase () u) +substRanksType (Scalar (TypeVar vn (QualName qs x) targs)) = do + when (null qs) $ addRankInfo x + targs' <- mapM onTypeArg targs + pure $ Scalar $ TypeVar vn (QualName qs x) targs' + where + onTypeArg (TypeArgType t) = TypeArgType <$> substRanksType t + -- SVar cannot occur as argument to abstract ype. + onTypeArg (TypeArgDim _) = pure $ TypeArgDim () +substRanksType (Scalar (Arrow u p d ta (RetType retdims tr))) = do + ta' <- substRanksType ta + tr' <- substRanksType tr + pure $ Scalar (Arrow u p d ta' (RetType retdims tr')) +substRanksType (Scalar (Record fs)) = + Scalar . Record <$> traverse substRanksType fs +substRanksType (Scalar (Sum cs)) = + Scalar . Sum <$> (traverse . traverse) substRanksType cs +substRanksType (Scalar (Prim pt)) = pure $ Scalar $ Prim pt +substRanksType (Array u shape t) = do + shape' <- substRanksShape shape + t' <- substRanksType $ Scalar t + pure $ arrayOfWithAliases u shape' t' + +substRanksCt :: (MonadTypeChecker m) => CtTy SComp -> SubstT m (CtTy ()) +substRanksCt (CtEq r t1 t2) = + CtEq + <$> traverse substRanksType r + <*> substRanksType t1 + <*> substRanksType t2 + +substRanksTyVarInfo :: (MonadTypeChecker m) => TyVarInfo SComp -> SubstT m (TyVarInfo ()) +substRanksTyVarInfo (TyVarFree loc l) = pure $ TyVarFree loc l +substRanksTyVarInfo (TyVarPrim loc ts) = pure $ TyVarPrim loc ts +substRanksTyVarInfo (TyVarRecord loc fs) = + TyVarRecord loc <$> traverse substRanksType fs +substRanksTyVarInfo (TyVarSum loc cs) = + TyVarSum loc <$> traverse (traverse substRanksType) cs + +substRanksTyVars :: (MonadTypeChecker m) => TyVars SComp -> SubstT m (TyVars ()) +substRanksTyVars = traverse $ \(lvl, tv) -> (lvl,) <$> substRanksTyVarInfo tv + +updAM :: RankMap -> Exp -> Exp +updAM rank_map e = + case e of + AppExp (Apply f args loc) res -> + let f' = updAM rank_map f + args' = fmap (bimap (fmap $ second upd) (updAM rank_map)) args + in AppExp (Apply f' args' loc) res + AppExp (BinOp op t (x, Info (xv, xam)) (y, Info (yv, yam)) loc) res -> + AppExp (BinOp op t (updAM rank_map x, Info (xv, upd xam)) (updAM rank_map y, Info (yv, upd yam)) loc) res + OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext, am)) t2 loc -> + OpSectionRight + name + t + (updAM rank_map arg) + (Info (pa, t1a), Info (pb, t1b, argext, upd am)) + t2 + loc + OpSectionLeft name t arg (Info (pa, t1a, argext, am), Info (pb, t1b)) (ret, retext) loc -> + OpSectionLeft + name + t + (updAM rank_map arg) + (Info (pa, t1a, argext, upd am), Info (pb, t1b)) + (ret, retext) + loc + _ -> runIdentity $ astMap mapper e + where + dimToRank (Var (QualName [] x) _ _) = + replicate (rank_map M.! x) (TupLit mempty mempty) + dimToRank e' = error $ prettyString e' + shapeToRank = Shape . foldMap dimToRank + upd (AutoMap r m f) = + AutoMap (shapeToRank r) (shapeToRank m) (shapeToRank f) + mapper = identityMapper {mapOnExp = pure . updAM rank_map} + +updAMPat :: RankMap -> Pat ParamType -> Pat ParamType +updAMPat rank_map p = runIdentity $ astMap m p + where + m = identityMapper {mapOnExp = pure . updAM rank_map} + +updAMTypeExp :: + RankMap -> + TypeExp Exp VName -> + TypeExp Exp VName +updAMTypeExp rank_map te = runIdentity $ astMap m te + where + m = identityMapper {mapOnExp = pure . updAM rank_map} diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index cf65002f83..4314fc0018 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1214,14 +1214,6 @@ localChecks = void . check e <$ case ty of Info (Scalar (Prim t)) -> errorBounds (inBoundsI (-x) t) (-x) t (loc1 <> loc2) _ -> error "Inferred type of int literal is not a number" - check e@(AppExp (BinOp (QualName [] v, _) _ (x, _) _ loc) _) - | baseName v == "==", - Array {} <- typeOf x, - baseTag v <= maxIntrinsicTag = do - warn loc $ - textwrap - "Comparing arrays with \"==\" is deprecated and will stop working in a future revision of the language." - recurse e check e = recurse e recurse = astMap identityMapper {mapOnExp = check} diff --git a/src/Language/Futhark/TypeChecker/Terms/Unsized.hs b/src/Language/Futhark/TypeChecker/Terms/Unsized.hs index a0794a4948..8c5f7a80cb 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Unsized.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Unsized.hs @@ -233,6 +233,11 @@ newTypeOverloaded :: SrcLoc -> Name -> [PrimType] -> TermM (TypeBase d NoUniquen newTypeOverloaded loc name pts = tyVarType NoUniqueness <$> newTyVarWith name (TyVarPrim (locOf loc) pts) +newSVar :: loc -> Name -> TermM SVar +newSVar _loc desc = do + i <- incCounter + newID $ mkTypeVarName desc i + newArtificial :: u -> TypeBase SComp u -> TermM (TypeBase Size u) newArtificial u t = do v <- newID "artificial" @@ -281,6 +286,12 @@ ctEq reason t1 t2 = t1' = t1 `setUniqueness` NoUniqueness t2' = t2 `setUniqueness` NoUniqueness +ctAM :: Reason (CtType SComp) -> SVar -> SVar -> Shape SComp -> TermM () +ctAM reason r m f = + modify $ \s -> s {termAM = ct : termAM s} + where + ct = CtAM reason r m f + localScope :: (TermScope -> TermScope) -> TermM a -> TermM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} @@ -647,23 +658,52 @@ checkApplyOne :: (Shape Size, Type) -> (Maybe Exp, Shape Size, Type) -> TermM (Type, AutoMap) -checkApplyOne loc fname (_fframe, ftype) (arg, _argframe, argtype) = do +checkApplyOne loc fname (fframe, ftype) (arg, argframe, argtype) = do (a, b) <- split ftype - let lhs = argtype - rhs = a + r <- newSVar loc "R" + m <- newSVar loc "M" + let unit_info = Info $ Scalar $ Prim Bool + r_var = Var (QualName [] r) unit_info mempty + m_var = Var (QualName [] m) unit_info mempty + lhs = arrayOf (toShape (SVar r)) argtype + rhs = arrayOf (toShape (SVar m)) a + ctAM (Reason (locOf loc)) r m $ fmap toSComp (toShape m_var <> fframe) let reason = case arg of Just arg' -> ReasonApply (locOf arg) fname arg' lhs rhs Nothing -> Reason (locOf loc) ctEq reason lhs rhs + debugTraceM 3 $ + unlines + [ "## checkApplyOne", + "## fname", + prettyString fname, + "## (fframe, ftype)", + prettyString (fframe, ftype), + "## (argframe, argtype)", + prettyString (argframe, argtype), + "## r", + prettyString r, + "## m", + prettyString m, + "## lhs", + prettyString lhs, + "## rhs", + prettyString rhs, + "## ret", + prettyString $ arrayOf (toShape (SVar m)) b + ] pure - ( b, + ( arrayOf (toShape (SVar m)) b, AutoMap - { autoRep = mempty, - autoMap = mempty, - autoFrame = mempty + { autoRep = toShape r_var, + autoMap = toShape m_var, + autoFrame = toShape m_var <> fframe } ) where + toSComp (Var (QualName [] x) _ _) = SVar x + toSComp _ = error "" + toShape = Shape . pure split (Scalar (Arrow _ _ _ a (RetType _ b))) = pure (a, b `setUniqueness` NoUniqueness) split (Array _u s t) = do @@ -1215,8 +1255,6 @@ doDefault :: Either [PrimType] (TypeBase () NoUniqueness) -> TermM (TypeBase () NoUniqueness) doDefault tyvars_at_toplevel v (Left pts) - | [pt] <- pts = - pure $ Scalar $ Prim pt | Signed Int32 `elem` pts = do when (v `elem` tyvars_at_toplevel) $ warn usage "Defaulting ambiguous type to i32." diff --git a/src/Language/Futhark/TypeChecker/TySolve.hs b/src/Language/Futhark/TypeChecker/TySolve.hs index 82afb9e190..75ebead21d 100644 --- a/src/Language/Futhark/TypeChecker/TySolve.hs +++ b/src/Language/Futhark/TypeChecker/TySolve.hs @@ -519,8 +519,8 @@ unionTyVars reason bcs v v_node t_node = do TyVarPrim t_loc t_pts ) -> let pts = L.intersect v_pts t_pts - in case pts of - [] -> + in if null pts + then pure $ Left ( locOf reason, @@ -530,7 +530,7 @@ unionTyVars reason bcs v v_node t_node = do "with type that must be one of" indent 2 (pretty t_pts) ) - _ -> pure $ Right $ Just $ Unsolved $ TyVarPrim t_loc pts + else pure $ Right $ Just $ Unsolved $ TyVarPrim t_loc pts (Unsolved (TyVarPrim _ v_pts), TyVarRecord {}) -> pure $ Left diff --git a/tests/automap/ambiguous0.fut b/tests/automap/ambiguous0.fut new file mode 100644 index 0000000000..8c1ec556c3 --- /dev/null +++ b/tests/automap/ambiguous0.fut @@ -0,0 +1,4 @@ +-- == +-- error: ambiguous + +def ambig (xss : [][]i32) = i64.sum (length xss) diff --git a/tests/automap/bool1.fut b/tests/automap/bool1.fut new file mode 100644 index 0000000000..f3fe08213e --- /dev/null +++ b/tests/automap/bool1.fut @@ -0,0 +1,6 @@ +-- == +-- entry: f +-- input { [true, true, false] [false, true, true] } +-- output { [true, true, true] } + +def f [m] (xs: [m]bool) (ys: [m]bool) = xs || ys diff --git a/tests/automap/combinations.fut b/tests/automap/combinations.fut new file mode 100644 index 0000000000..7d77e85abb --- /dev/null +++ b/tests/automap/combinations.fut @@ -0,0 +1,38 @@ +-- All the various ways one can imagine automapping a very simple program. + +def plus (x: i32) (y: i32) = x + y + +-- == +-- entry: vecint +-- input { [1,2,3] } output { [3,4,5] } + +entry vecint (x: []i32) = plus x 2 + +-- == +-- entry: vecvec +-- input { [1,2,3] } output { [2,4,6] } + +entry vecvec (x: []i32) = plus x x + +-- == +-- entry: matint +-- input { [[1,2],[3,4]] } output { [[3,4],[5,6]] } + +entry matint (x: [][]i32) = plus x 2 + +-- == +-- entry: matmat +-- input { [[1,2],[3,4]] } output { [[2,4],[6,8]] } + +entry matmat (x: [][]i32) = plus x x + +-- == +-- entry: matvec +-- input { [[1,2],[3,4]] [5,6] } output { [[6,8],[8,10]] } + +entry matvec (x: [][]i32) (y: []i32) = plus x y + +-- == +-- entry: vecvecvec +-- input { [1,2,3] } output { [3,6,9] } +entry vecvecvec (x: []i32) = (\x y z -> x + y + z) x x x diff --git a/tests/automap/equality1.fut b/tests/automap/equality1.fut new file mode 100644 index 0000000000..b2a173f30d --- /dev/null +++ b/tests/automap/equality1.fut @@ -0,0 +1,23 @@ +-- == +-- entry: bigger_to_smaller +-- input { [[1,2],[3,4]] [1,2] } +-- output { [[true, true], [false, false]] } + +-- == +-- entry: smaller_to_bigger +-- input { [[1,2],[3,4]] [1,2] } +-- output { [[true, true], [false, false]] } + +-- == +-- entry: smaller_to_bigger2 +-- input { [[1,2],[3,4]] 1 } +-- output { [[true,false],[false,false]]} + +entry bigger_to_smaller [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = + xss == ys + +entry smaller_to_bigger [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]bool = + ys == xss + +entry smaller_to_bigger2 [n] (xss : [n][n]i32) (z: i32) : [n][n]bool = + z == xss diff --git a/tests/automap/lambda.fut b/tests/automap/lambda.fut new file mode 100644 index 0000000000..1bb7ed26e3 --- /dev/null +++ b/tests/automap/lambda.fut @@ -0,0 +1,6 @@ +-- == +-- entry: main +-- random input { [10]f32 [10]f32 } + +entry main [n](xs: [n]f32) (ys: [n]f32): [n]f32 = + map2 (*) xs ys diff --git a/tests/automap/leetcode.fut b/tests/automap/leetcode.fut new file mode 100644 index 0000000000..43a50cb2b8 --- /dev/null +++ b/tests/automap/leetcode.fut @@ -0,0 +1,4 @@ +def outerprod f x y = map (f >-> flip map y) x +def bidd A = outerprod (==) (indices A) (indices A) +def xmat A = bidd A || reverse (bidd A) +def check_matrix (A : [][]i32) = xmat A == (A != 0) |> flatten |> and diff --git a/tests/automap/map0.fut b/tests/automap/map0.fut new file mode 100644 index 0000000000..a5ab0887ae --- /dev/null +++ b/tests/automap/map0.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [0,1,2,3] } +-- output { [1,2,3,4] } + +def automap 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = f as + +entry main (x: []i32) = automap (+1) x diff --git a/tests/automap/mri-q-qr.fut b/tests/automap/mri-q-qr.fut new file mode 100644 index 0000000000..8004f7da5d --- /dev/null +++ b/tests/automap/mri-q-qr.fut @@ -0,0 +1,2 @@ +def qr [numX][numK] (expArgs : [numX][numK]f32) (phiMag : [numK]f32) : [numX]f32 = + f32.sum (f32.cos expArgs * phiMag) diff --git a/tests/automap/mri-q.fut b/tests/automap/mri-q.fut new file mode 100644 index 0000000000..270e18195a --- /dev/null +++ b/tests/automap/mri-q.fut @@ -0,0 +1,41 @@ +-- == +-- entry: main +-- random input { [12]f32 [12]f32 [12]f32 [10]f32 [10]f32 [10]f32 [12]f32 [12]f32 } +-- output { true } + +def main_orig [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numX]f32, [numX]f32) = + let phiMag = map2 (\r i -> r*r + i*i) phiR phiI + let expArgs = map3 (\x_e y_e z_e -> + map (2.0f32*f32.pi*) + (map3 (\kx_e ky_e kz_e -> + kx_e * x_e + ky_e * y_e + kz_e * z_e) + kx ky kz)) + x y z + let qr = map1 (map f32.cos >-> map2 (*) phiMag >-> f32.sum) expArgs + let qi = map1 (map f32.sin >-> map2 (*) phiMag >-> f32.sum) expArgs + in (qr, qi) + +def main_am [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) + : ([numX]f32, [numX]f32) = + let phiMag = phiR * phiR + phiI * phiI + let expArgs = map3 (\x_e y_e z_e -> + 2.0*f32.pi*(kx*x_e + ky*y_e + kz*z_e)) + x y z + let qr = f32.sum (f32.cos expArgs * phiMag) + let qi = f32.sum (f32.sin expArgs * phiMag) + in (qr, qi) + +entry main [numK][numX] + (kx: [numK]f32) (ky: [numK]f32) (kz: [numK]f32) + (x: [numX]f32) (y: [numX]f32) (z: [numX]f32) + (phiR: [numK]f32) (phiI: [numK]f32) = + let (qr, qi) = main_orig kx ky kz x y z phiR phiI + let (qr_am, qi_am) = main_am kx ky kz x y z phiR phiI + in and (qr == qr_am && qi == qi_am) diff --git a/tests/automap/operator1.fut b/tests/automap/operator1.fut new file mode 100644 index 0000000000..464a8b79c4 --- /dev/null +++ b/tests/automap/operator1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [10,20] } +-- output { [[11, 22],[13, 24]] } + +def (+^) [n] (xs: [n]i32) (ys: [n]i32) : [n]i32 = xs + ys + +--entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = +-- xss +^ ys diff --git a/tests/automap/optionpricing.fut b/tests/automap/optionpricing.fut new file mode 100644 index 0000000000..c4c916521f --- /dev/null +++ b/tests/automap/optionpricing.fut @@ -0,0 +1,78 @@ +-- == +-- entry: sobolIndR +-- random input { [12][10]i32 i32 } +-- output { true } + +-- == +-- entry: sobolRecI +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +-- == +-- entry: sobolReci2 +-- random input { [12][10]i32 [12]i32 i32} +-- output { true } + +def grayCode(x: i32): i32 = (x >> 1) ^ x + +def testBit(n: i32, ind: i32): bool = + let t = (1 << ind) in (n & t) == t + +def xorInds [num_bits] (n: i32) (dir_vs: [num_bits]i32): i32 = + let reldv_vals = map (\(dv: i32, i): i32 -> + if testBit(grayCode(n),i32.i64 i) + then dv else 0 + ) (zip (dir_vs) (iota(num_bits)) ) in + reduce (^) 0 (reldv_vals ) + + +def sobolIndI [len] (dir_vs: [len][]i32, n: i32 ): [len]i32 = + map (xorInds(n)) (dir_vs ) + +def index_of_least_significant_0(num_bits: i32, n: i32): i32 = + let (goon,k) = (true,0) in + let (_,k,_) = loop ((goon,k,n)) for i < num_bits do + if(goon) + then if (n & 1) == 1 + then (true, k+1, n>>1) + else (false,k, n ) + else (false,k, n ) + in k + +def recM [len][num_bits] (sob_dirs: [len][num_bits]i32, i: i32 ): [len]i32 = + let bit= index_of_least_significant_0(i32.i64 num_bits,i) in + map (\(row: []i32): i32 -> row[bit]) (sob_dirs ) + +def sobolIndR_orig [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = map (xorInds n) dir_vs + in map (\x -> f32.i32(x) / divisor) arri + +def sobolRecI_orig [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in map2 (\vct_row prev -> vct_row[bit] ^ prev) sob_dir_vs prev + +def sobolReci2_orig [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + let col = recM(sob_dirs, i) + in map2 (^) prev col + +def sobolIndR_am [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): [m]f32 = + let divisor = 2.0 ** f32.i64(num_bits) + let arri = xorInds n dir_vs + in f32.i32 arri / divisor + +def sobolRecI_am [num_bits][n] (sob_dir_vs: [n][num_bits]i32, prev: [n]i32, x: i32): [n]i32 = + let bit = index_of_least_significant_0(i32.i64 num_bits, x) + in sob_dir_vs[:,bit] ^ prev + +def sobolReci2_am [n][num_bits] (sob_dirs: [n][num_bits]i32, prev: [n]i32, i: i32): [n]i32= + prev ^ recM(sob_dirs, i) + +entry sobolIndR [m][num_bits] (dir_vs: [m][num_bits]i32) (n: i32): bool = + and (sobolIndR_orig dir_vs n == sobolIndR_am dir_vs n) + +entry sobolRecI [num_bits][n] (sob_dir_vs: [n][num_bits]i32) (prev: [n]i32) (x: i32): bool = + and (sobolRecI_orig (sob_dir_vs, prev, x) == sobolRecI_am (sob_dir_vs, prev, x)) + +entry sobolReci2 [n][num_bits] (sob_dirs: [n][num_bits]i32) (prev: [n]i32) (i: i32): bool = + and (sobolReci2_orig (sob_dirs, prev, i) == sobolReci2_am (sob_dirs, prev, i)) diff --git a/tests/automap/pagerank.fut b/tests/automap/pagerank.fut new file mode 100644 index 0000000000..3552990144 --- /dev/null +++ b/tests/automap/pagerank.fut @@ -0,0 +1,18 @@ +-- == +-- entry: calculate_dangling_ranks +-- random input { [12]f32 [12]i32} +-- output { true } + +def calculate_dangling_ranks_orig [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let zipped = zip sizes ranks + let weights = map (\(size, rank) -> if size == 0 then rank else 0f32) zipped + let total = f32.sum weights / f32.i64 n + in map (+total) ranks + +def calculate_dangling_ranks_am [n] (ranks: [n]f32) (sizes: [n]i32): *[]f32 = + let weights = f32.bool (sizes == 0) * ranks + let total = f32.sum weights / f32.i64 n + in ranks + total + +entry calculate_dangling_ranks [n] (ranks: [n]f32) (sizes: [n]i32): bool = + and (calculate_dangling_ranks_orig ranks sizes == calculate_dangling_ranks_am ranks sizes) diff --git a/tests/automap/project.fut b/tests/automap/project.fut new file mode 100644 index 0000000000..2902d0565a --- /dev/null +++ b/tests/automap/project.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in xsys.0 ++ xsys.1 + diff --git a/tests/automap/projsec1.fut b/tests/automap/projsec1.fut new file mode 100644 index 0000000000..485c977bc5 --- /dev/null +++ b/tests/automap/projsec1.fut @@ -0,0 +1,9 @@ +-- == +-- entry: main +-- input { [1,2,3] [4,5,6] } +-- output { [1,2,3,4,5,6] } + +entry main [n] (xs: [n]i32) (ys: [n]i32) : []i32 = + let xsys = zip xs ys + in (.0) xsys ++ (.1) xsys + diff --git a/tests/automap/same_typevar.fut b/tests/automap/same_typevar.fut new file mode 100644 index 0000000000..260a00b785 --- /dev/null +++ b/tests/automap/same_typevar.fut @@ -0,0 +1,16 @@ +-- == +-- tags { no_wasm } +-- entry: big_to_small +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +-- == +-- entry: small_to_big +-- no_wasm compiled input { [[1,2],[3,4]] [1,2] 3 } + +def f 'a (x: a) (y: a) (z: a) = (x, y, z) + +entry big_to_small [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f xss ys z + +entry small_to_big [n] (xss : [n][n]i32) (ys: [n]i32) (z: i32) : [n][n](i32,i32,i32) = + f z ys xss diff --git a/tests/automap/sgemm.fut b/tests/automap/sgemm.fut new file mode 100644 index 0000000000..a31ce0188e --- /dev/null +++ b/tests/automap/sgemm.fut @@ -0,0 +1,32 @@ +-- == +-- entry: main +-- random input { [5][10]f32 [10][3]f32 [5][3]f32 f32 f32 } +-- output { true } + +def mult_orig [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + let dotprod xs ys = f32.sum (map2 (*) xs ys) + in map (\xs -> map (dotprod xs) (transpose yss)) xss + +def add [n][m] (xss: [n][m]f32, yss: [n][m]f32): [n][m]f32 = + map2 (map2 (+)) xss yss + +def scale [n][m] (xss: [n][m]f32, a: f32): [n][m]f32 = + map (map1 (*a)) xss + +def main_orig [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + add(scale(css,beta), scale(mult_orig(ass,bss), alpha)) + + +def mult_am [n][m][p] (xss: [n][m]f32, yss: [m][p]f32): [n][p]f32 = + f32.sum ((transpose (replicate p xss)) * (replicate n (transpose yss))) + +def main_am [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) + : [n][p]f32 = + css*beta + mult_am(ass,bss)*alpha + +entry main [n][m][p] (ass: [n][m]f32) (bss: [m][p]f32) (css: [n][p]f32) + (alpha: f32) (beta: f32) = + and (and (main_orig ass bss css alpha beta == main_am ass bss css alpha beta)) diff --git a/tests/automap/simple1.fut b/tests/automap/simple1.fut new file mode 100644 index 0000000000..f8833bb3b6 --- /dev/null +++ b/tests/automap/simple1.fut @@ -0,0 +1,7 @@ +-- == +-- entry: main +-- input { [1,2] 10 } +-- output { [11, 12] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + xs + y diff --git a/tests/automap/simple2.fut b/tests/automap/simple2.fut new file mode 100644 index 0000000000..ac57abcbe0 --- /dev/null +++ b/tests/automap/simple2.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + xss + ys + diff --git a/tests/automap/simple3.fut b/tests/automap/simple3.fut new file mode 100644 index 0000000000..adc60bd43f --- /dev/null +++ b/tests/automap/simple3.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { [[1,2],[3,4]] [1,1] } +-- output { [[2,3],[4,5]] } + +entry main [n] (xss : [n][n]i32) (ys: [n]i32) : [n][n]i32 = + ys + xss + diff --git a/tests/automap/simple4.fut b/tests/automap/simple4.fut new file mode 100644 index 0000000000..d94bbe4a6b --- /dev/null +++ b/tests/automap/simple4.fut @@ -0,0 +1,8 @@ +-- == +-- entry: main +-- input { 3 [1,1] [[1,2],[3,4]] } +-- output { [[5,6],[7,8]] } + +entry main [n] (x : i32) (ys: [n]i32) (zss : [n][n]i32) : [n][n]i32 = + x + ys + zss + diff --git a/tests/automap/simple5.fut b/tests/automap/simple5.fut new file mode 100644 index 0000000000..46610e6567 --- /dev/null +++ b/tests/automap/simple5.fut @@ -0,0 +1,6 @@ +-- == +-- input { [1,2,3] 4 } +-- output { [5, 6, 7] } + +entry main [n] (xs: [n]i32) (y : i32) : [n]i32 = + (\x y -> x + y) xs y diff --git a/tests/issue1599.fut b/tests/issue1599.fut deleted file mode 100644 index 3ce47c38b1..0000000000 --- a/tests/issue1599.fut +++ /dev/null @@ -1,4 +0,0 @@ --- == --- error: Occurs - -let bad a f = f a f diff --git a/tests/issue1926.fut b/tests/issue1926.fut index feaef47175..6f79db9bb4 100644 --- a/tests/issue1926.fut +++ b/tests/issue1926.fut @@ -1,12 +1,11 @@ -- == --- error: cannot unify type with constructors +-- error: cannot match value type found = #found i32 | #not_found def main = let o = map (\x -> if (x > 3) then (#found x) else (#not_found)) [0, 1, 2, 3, 4] - let u = - match o + let u = match o case #found x -> x case #not_found -> -1 in u diff --git a/tests/types/inference5.fut b/tests/types/inference5.fut deleted file mode 100644 index d05b9084aa..0000000000 --- a/tests/types/inference5.fut +++ /dev/null @@ -1,7 +0,0 @@ --- let should not be generalised --- == --- error: Cannot apply "apply" - -def main x = - let apply f x = f x - in apply (apply (i32.+) x) x From 206f8e071fc5114bbbc22da9151e5bae60dd40d2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 30 Jul 2025 15:10:08 +0200 Subject: [PATCH 294/296] Use non-automap futhark-benchmarks --- futhark-benchmarks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/futhark-benchmarks b/futhark-benchmarks index 0d427b9483..13d3cb5cb2 160000 --- a/futhark-benchmarks +++ b/futhark-benchmarks @@ -1 +1 @@ -Subproject commit 0d427b94838beea4d6512f7639860e5b967ce7bc +Subproject commit 13d3cb5cb2c887adca2bf4fbd02f9e866436cbfe From 66bedb37b50cbe95a84929d20a2f25124c38fb9f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 30 Jul 2025 16:31:45 +0200 Subject: [PATCH 295/296] Make this more generous. --- src/Language/Futhark/TypeChecker/Unify.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index fa53949c5c..7292ae2638 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -810,11 +810,13 @@ unifyMostCommon usage t1 t2 = do mapM_ (uncurry $ onDims bcs bound nonrigid) es onDims bcs _ nonrigid (Var v1 _ _) e2 | Just lvl1 <- nonrigid (qualLeaf v1), - expLevel e2 < lvl1 = + expLevel e2 <= lvl1, + not $ qualLeaf v1 `S.member` fvVars (freeInExp e2) = linkVarToDim usage bcs (qualLeaf v1) lvl1 e2 onDims bcs _ nonrigid e1 (Var v2 _ _) | Just lvl2 <- nonrigid (qualLeaf v2), - expLevel e1 < lvl2 = + expLevel e1 <= lvl2, + not $ qualLeaf v2 `S.member` fvVars (freeInExp e1) = linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 onDims _ _ _ _ _ = pure () From ea7607bc728c343204629cfcaee925a03dca02cb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 30 Jul 2025 16:40:05 +0200 Subject: [PATCH 296/296] Fix expected error. --- tests/issue1926.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/issue1926.fut b/tests/issue1926.fut index feaef47175..63df4afd2f 100644 --- a/tests/issue1926.fut +++ b/tests/issue1926.fut @@ -1,5 +1,5 @@ -- == --- error: cannot unify type with constructors +-- error: Cannot unify type with constructors type found = #found i32 | #not_found