// Compiler projects using llvm
//===----- TypePromotion.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This is an opcode based type promotion pass for small types that would
/// otherwise be promoted during legalisation. This works around the limitations
/// of selection dag for cyclic regions. The search begins from the operands of
/// icmp instructions, from which a tree consisting of non-wrapping or safe
/// wrapping instructions is built, checked and promoted if possible.
///
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"

#define DEBUG_TYPE "type-promotion"
#define PASS_NAME "Type Promotion"

using namespace llvm;

// Hidden command-line flag (-disable-type-promotion). When it is set,
// TypePromotion::runOnFunction returns early without touching the function.
static cl::opt<bool> DisablePromotion("disable-type-promotion", cl::Hidden,
                                      cl::init(false),
                                      cl::desc("Disable type promotion pass"));

// The goal of this pass is to enable more efficient code generation for
// operations on narrow types (i.e. types with < 32-bits) and this is a
// motivating IR code example:
//
//   define hidden i32 @cmp(i8 zeroext) {
//     %2 = add i8 %0, -49
//     %3 = icmp ult i8 %2, 3
//     ..
//   }
//
// The issue here is that i8 is type-legalized to i32 because i8 is not a
// legal type. Thus, arithmetic is done in integer-precision, but then the
// byte value is masked out as follows:
//
//   t19: i32 = add t4, Constant:i32<-49>
//     t24: i32 = and t19, Constant:i32<255>
//
// Consequently, we generate code like this:
//
//   subs  r0, #49
//   uxtb  r1, r0
//   cmp r1, #3
//
// This shows that masking out the byte value results in generation of
// the UXTB instruction. This is not optimal as r0 already contains the byte
// value we need, and so instead we can just generate:
//
//   sub.w r1, r0, #49
//   cmp r1, #3
//
// We achieve this by type promoting the IR to i32 like so for this example:
//
//   define i32 @cmp(i8 zeroext %c) {
//     %0 = zext i8 %c to i32
//     %c.off = add i32 %0, -49
//     %1 = icmp ult i32 %c.off, 3
//     ..
//   }
//
// For this to be valid and legal, we need to prove that the i32 add is
// producing the same value as the i8 addition, and that e.g. no overflow
// happens.
//
// A brief sketch of the algorithm and some terminology.
// We pattern match interesting IR patterns:
// - which have "sources": instructions producing narrow values (i8, i16), and
// - they have "sinks": instructions consuming these narrow values.
//
// We collect all instructions connecting sources and sinks in a worklist, so
// that we can mutate these instructions and perform type promotion when it is
// legal to do so.

namespace {
/// Performs the actual IR rewrite for one promotion tree: given the visited
/// values, their sources and sinks, it widens the tree to an integer type of
/// PromotedWidth bits. The sets are supplied by the caller and held by
/// reference.
class IRPromoter {
  // Context used to create ExtTy and the IRBuilders.
  LLVMContext &Ctx;
  // Bit width that the tree is promoted to.
  unsigned PromotedWidth = 0;
  // Every value that belongs to the tree being promoted.
  SetVector<Value *> &Visited;
  // Values that get a zext inserted after them (see ExtendSources).
  SetVector<Value *> &Sources;
  // Instructions whose operands may need truncating back (see TruncateSinks).
  SetVector<Instruction *> &Sinks;
  // Instructions whose constant operands are sign extended, rather than zero
  // extended, when promoted (see PromoteTree).
  SmallPtrSetImpl<Instruction *> &SafeWrap;
  // The integer type of PromotedWidth bits; set in the constructor.
  IntegerType *ExtTy = nullptr;
  // Instructions created by the promoter (zexts, truncs and masks).
  SmallPtrSet<Value *, 8> NewInsts;
  // Instructions that had all of their users replaced; erased in Cleanup.
  SmallPtrSet<Instruction *, 4> InstsToRemove;
  // Original types recorded in Mutate before any rewriting: the operand types
  // of each sink and the destination type of each non-source trunc.
  DenseMap<Value *, SmallVector<Type *, 4>> TruncTysMap;
  // Sources that have been extended and instructions whose type was mutated.
  SmallPtrSet<Value *, 8> Promoted;

  // Replace uses of From with To and, if every user was rewritten, queue From
  // (when it is an instruction) for removal.
  void ReplaceAllUsersOfWith(Value *From, Value *To);
  // Insert a zext to ExtTy for every source.
  void ExtendSources();
  // Replace non-source truncs in the tree with AND masks.
  void ConvertTruncs();
  // Widen constant operands and mutate the result types of the tree.
  void PromoteTree();
  // Insert truncs in front of sinks that consume promoted values.
  void TruncateSinks();
  // Remove redundant zexts/truncs and erase queued instructions.
  void Cleanup();

public:
  /// \param C       context used for type and instruction creation.
  /// \param Width   bit width to promote to; ExtTy is derived from it.
  /// \param visited all values in the tree.
  /// \param sources the leaves of the tree that will be zero extended.
  /// \param sinks   the instructions that consume values of the tree.
  /// \param wrap    instructions whose constants must be sign extended.
  IRPromoter(LLVMContext &C, unsigned Width,
             SetVector<Value *> &visited, SetVector<Value *> &sources,
             SetVector<Instruction *> &sinks,
             SmallPtrSetImpl<Instruction *> &wrap)
      : Ctx(C), PromotedWidth(Width), Visited(visited),
        Sources(sources), Sinks(sinks), SafeWrap(wrap) {
    ExtTy = IntegerType::get(Ctx, PromotedWidth);
  }

  /// Run the whole promotion sequence on the tree.
  void Mutate();
};

/// Legacy function pass that searches from unsigned icmp operands for trees of
/// narrow integer operations and, when legal, promotes them with IRPromoter.
class TypePromotion : public FunctionPass {
  // Bit size of the value currently being considered; set at the start of
  // TryToPromote and used by the *TypeSize helpers below.
  unsigned TypeSize = 0;
  // Context of the function's module; set in runOnFunction.
  LLVMContext *Ctx = nullptr;
  // Scalar register width reported by TargetTransformInfo.
  unsigned RegisterBitWidth = 0;
  // Values visited by any TryToPromote call during the current function.
  SmallPtrSet<Value *, 16> AllVisited;
  // Cache of instructions already accepted by isLegalToPromote.
  SmallPtrSet<Instruction *, 8> SafeToPromote;
  // Instructions recorded by isSafeWrap; handed to IRPromoter.
  SmallPtrSet<Instruction *, 4> SafeWrap;

  // Does V have the same size result type as TypeSize.
  bool EqualTypeSize(Value *V);
  // Does V have the same size, or narrower, result type as TypeSize.
  bool LessOrEqualTypeSize(Value *V);
  // Does V have a result type that is wider than TypeSize.
  bool GreaterThanTypeSize(Value *V);
  // Does V have a result type that is narrower than TypeSize.
  bool LessThanTypeSize(Value *V);
  // Should V be a leaf in the promotion tree?
  bool isSource(Value *V);
  // Should V be a root in the promotion tree?
  bool isSink(Value *V);
  // Should we change the result type of V? It will result in the users of V
  // being visited.
  bool shouldPromote(Value *V);
  // Is I an add or a sub, which isn't marked as nuw, but where a wrapping
  // result won't affect the computation?
  bool isSafeWrap(Instruction *I);
  // Can V have its integer type promoted, or can the type be ignored.
  bool isSupportedType(Value *V);
  // Is V an instruction with a supported opcode or another value that we can
  // handle, such as constants and basic blocks.
  bool isSupportedValue(Value *V);
  // Is V an instruction whose result can trivially be promoted, or which has
  // safe wrapping.
  bool isLegalToPromote(Value *V);
  // Try to build a promotion tree around V and promote it to PromotedWidth
  // bits. Returns true if the IR was changed.
  bool TryToPromote(Value *V, unsigned PromotedWidth);

public:
  static char ID;

  TypePromotion() : FunctionPass(ID) {}

  // Requires TargetTransformInfo and TargetPassConfig; the CFG is preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesCFG();
  }

  StringRef getPassName() const override { return PASS_NAME; }

  bool runOnFunction(Function &F) override;
};

} // namespace

/// Return true if \p I has one of the opcodes this pass treats as generating
/// sign bits: ashr, sdiv, srem or sext.
static bool GenerateSignBits(Instruction *I) {
  switch (I->getOpcode()) {
  case Instruction::AShr:
  case Instruction::SDiv:
  case Instruction::SRem:
  case Instruction::SExt:
    return true;
  default:
    return false;
  }
}

/// Return true if the scalar bit width of V's type is exactly TypeSize.
bool TypePromotion::EqualTypeSize(Value *V) {
  const unsigned Bits = V->getType()->getScalarSizeInBits();
  return Bits == TypeSize;
}

/// Return true if the scalar bit width of V's type does not exceed TypeSize.
bool TypePromotion::LessOrEqualTypeSize(Value *V) {
  const unsigned Bits = V->getType()->getScalarSizeInBits();
  return Bits <= TypeSize;
}

/// Return true if the scalar bit width of V's type is larger than TypeSize.
bool TypePromotion::GreaterThanTypeSize(Value *V) {
  const unsigned Bits = V->getType()->getScalarSizeInBits();
  return Bits > TypeSize;
}

/// Return true if the scalar bit width of V's type is smaller than TypeSize.
bool TypePromotion::LessThanTypeSize(Value *V) {
  const unsigned Bits = V->getType()->getScalarSizeInBits();
  return Bits < TypeSize;
}

/// Decide whether \p V is a leaf (source) of the promotion tree, i.e. a value
/// that will be zero extended to start the promotion. Only integer-typed
/// values qualify. Arguments, loads and bitcasts are always sources; a call is
/// a source only if its return value carries the zeroext attribute; a trunc is
/// a source only if its result is exactly TypeSize bits wide.
bool TypePromotion::isSource(Value *V) {
  if (!isa<IntegerType>(V->getType()))
    return false;

  // TODO Allow zext to be sources.
  if (isa<Argument>(V) || isa<LoadInst>(V) || isa<BitCastInst>(V))
    return true;

  if (auto *Call = dyn_cast<CallInst>(V))
    return Call->hasRetAttr(Attribute::AttrKind::ZExt);

  if (auto *Trunc = dyn_cast<TruncInst>(V))
    return EqualTypeSize(Trunc);

  return false;
}

/// Decide whether \p V is a root (sink) of the promotion tree: an instruction
/// whose own type cannot be mutated and which therefore may need promoted
/// operands to be truncated for the IR to remain valid.
bool TypePromotion::isSink(Value *V) {
  // TODO The truncate also isn't actually necessary because we would already
  // proved that the data value is kept within the range of the original data
  // type. We currently remove any truncs inserted for handling zext sinks.

  // Only instructions can be sinks.
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return false;

  switch (I->getOpcode()) {
  // Points where the value in the register is being observed.
  case Instruction::Store:
    return LessOrEqualTypeSize(cast<StoreInst>(I)->getValueOperand());
  case Instruction::Switch:
    return LessThanTypeSize(cast<SwitchInst>(I)->getCondition());
  case Instruction::ICmp: {
    auto *ICmp = cast<ICmpInst>(I);
    return ICmp->isSigned() || LessThanTypeSize(ICmp->getOperand(0));
  }

  // Points where value types have to match.
  case Instruction::Ret:
    return LessOrEqualTypeSize(cast<ReturnInst>(I)->getReturnValue());
  case Instruction::Call:
    return true;

  // zexts are included to ease the transformation and are generally removed
  // later on.
  case Instruction::ZExt:
    return GreaterThanTypeSize(I);

  default:
    return false;
  }
}

/// Return whether this instruction can safely wrap.
///
/// \p I qualifies only if it is an add or sub whose second operand is a
/// constant and whose single user is an unsigned, non-equality icmp that has a
/// constant operand too. On success \p I (and, in one of the two cases, the
/// icmp as well) is recorded in SafeWrap and true is returned.
bool TypePromotion::isSafeWrap(Instruction *I) {
  // We can support a potentially wrapping instruction (I) if:
  // - It is only used by an unsigned icmp.
  // - The icmp uses a constant.
  // - The wrapping value (I) is decreasing, i.e would underflow - wrapping
  //   around zero to become a larger number than before.
  // - The wrapping instruction (I) also uses a constant.
  //
  // The two constants are then used to decide whether the icmp gives the same
  // result whether or not the (possibly underflowed) value has been
  // truncated to the original bitwidth.
  //
  // For example:
  //
  // %sub = sub i8 %a, 2
  // %cmp = icmp ule i8 %sub, 254
  //
  // If %a = 0, %sub = -2 == FE == 254
  // But if this is evaluated as a i32
  // %sub = -2 == FF FF FF FE == 4294967294
  // So the unsigned compares (i8 and i32) would not yield the same result.
  //
  // Another way to look at it is:
  // %a - 2 <= 254
  // %a + 2 <= 254 + 2
  // %a <= 256
  // And we can't represent 256 in the i8 format, so we don't support it.
  //
  // Whereas:
  //
  // %sub = sub i8 %a, 1
  // %cmp = icmp ule i8 %sub, 254
  //
  // If %a = 0, %sub = -1 == FF == 255
  // As i32:
  // %sub = -1 == FF FF FF FF == 4294967295
  //
  // In this case, the unsigned compare results would be the same and this
  // would also be true for ult, uge and ugt:
  // - (255 < 254) == (0xFFFFFFFF < 254) == false
  // - (255 <= 254) == (0xFFFFFFFF <= 254) == false
  // - (255 > 254) == (0xFFFFFFFF > 254) == true
  // - (255 >= 254) == (0xFFFFFFFF >= 254) == true
  //
  // To demonstrate why we can't handle increasing values:
  //
  // %add = add i8 %a, 2
  // %cmp = icmp ult i8 %add, 127
  //
  // If %a = 254, %add = 256 == (i8 1)
  // As i32:
  // %add = 256
  //
  // (1 < 127) != (256 < 127)

  unsigned Opc = I->getOpcode();
  if (Opc != Instruction::Add && Opc != Instruction::Sub)
    return false;

  if (!I->hasOneUse() || !isa<ICmpInst>(*I->user_begin()) ||
      !isa<ConstantInt>(I->getOperand(1)))
    return false;

  // Don't support an icmp that deals with sign bits.
  auto *CI = cast<ICmpInst>(*I->user_begin());
  if (CI->isSigned() || CI->isEquality())
    return false;

  // The constant may appear on either side of the compare.
  ConstantInt *ICmpConstant = nullptr;
  if (auto *Const = dyn_cast<ConstantInt>(CI->getOperand(0)))
    ICmpConstant = Const;
  else if (auto *Const = dyn_cast<ConstantInt>(CI->getOperand(1)))
    ICmpConstant = Const;
  else
    return false;

  // Express the operation as an addition of OverflowConst, and only accept
  // it when that constant is not positive (a decreasing value).
  const APInt &ICmpConst = ICmpConstant->getValue();
  APInt OverflowConst = cast<ConstantInt>(I->getOperand(1))->getValue();
  if (Opc == Instruction::Sub)
    OverflowConst = -OverflowConst;
  if (!OverflowConst.isNonPositive())
    return false;

  // Using C1 = OverflowConst and C2 = ICmpConst, we can either prove that:
  //   zext(x) + sext(C1) <u zext(C2)  if C1 < 0 and C1 >s C2
  //   zext(x) + sext(C1) <u sext(C2)  if C1 < 0 and C1 <=s C2
  if (OverflowConst.sgt(ICmpConst)) {
    LLVM_DEBUG(dbgs() << "IR Promotion: Allowing safe overflow for sext "
                      << "const of " << *I << "\n");
    SafeWrap.insert(I);
    return true;
  }

  // Otherwise the icmp constant has to be sign extended as well, so the icmp
  // is recorded alongside I.
  LLVM_DEBUG(dbgs() << "IR Promotion: Allowing safe overflow for sext "
                    << "const of " << *I << " and " << *CI << "\n");
  SafeWrap.insert(I);
  SafeWrap.insert(CI);
  return true;
}

/// Return true if the result type of \p V should be changed by the promotion.
/// Non-integer values and sinks are never promoted; sources always are; any
/// other value is promoted only if it is an instruction other than an icmp.
bool TypePromotion::shouldPromote(Value *V) {
  if (!isa<IntegerType>(V->getType()))
    return false;

  if (isSink(V))
    return false;

  if (isSource(V))
    return true;

  if (auto *I = dyn_cast<Instruction>(V))
    return !isa<ICmpInst>(I);

  return false;
}

/// Return whether the type of \p I can be mutated to the wider type without
/// having to be concerned with zero extending or truncation: \p I must not
/// generate sign bits and, if it is an overflowing binary operator, it must be
/// marked nuw.
static bool isPromotedResultSafe(Instruction *I) {
  if (GenerateSignBits(I))
    return false;

  return !isa<OverflowingBinaryOperator>(I) || I->hasNoUnsignedWrap();
}

/// Redirect the users of \p From to \p To. A user that is identical to \p To
/// (when \p To is an instruction) is left alone. If no user was skipped and
/// \p From is an instruction, it is queued in InstsToRemove.
void IRPromoter::ReplaceAllUsersOfWith(Value *From, Value *To) {
  LLVM_DEBUG(dbgs() << "IR Promotion: Replacing " << *From << " with " << *To
                    << "\n");

  auto *ToInst = dyn_cast<Instruction>(To);
  SmallVector<Instruction *, 4> Rewrites;
  bool SkippedAny = false;

  // Collect the users first so that the use list isn't modified while it is
  // being walked.
  for (Use &U : From->uses()) {
    auto *UserInst = cast<Instruction>(U.getUser());
    if (ToInst && UserInst->isIdenticalTo(ToInst))
      SkippedAny = true;
    else
      Rewrites.push_back(UserInst);
  }

  for (Instruction *UserInst : Rewrites)
    UserInst->replaceUsesOfWith(From, To);

  if (SkippedAny)
    return;

  if (auto *FromInst = dyn_cast<Instruction>(From))
    InstsToRemove.insert(FromInst);
}

/// Insert a zext to ExtTy for every value in Sources, redirect the users of
/// the source to the zext and record the source in Promoted.
void IRPromoter::ExtendSources() {
  IRBuilder<> Builder{Ctx};

  // Create 'zext V to ExtTy' next to InsertPt and make the users of V use it
  // instead.
  auto InsertZExt = [&](Value *V, Instruction *InsertPt) {
    assert(V->getType() != ExtTy && "zext already extends to i32");
    LLVM_DEBUG(dbgs() << "IR Promotion: Inserting ZExt for " << *V << "\n");
    Builder.SetInsertPoint(InsertPt);
    if (auto *I = dyn_cast<Instruction>(V))
      Builder.SetCurrentDebugLocation(I->getDebugLoc());

    Value *ZExt = Builder.CreateZExt(V, ExtTy);
    if (auto *I = dyn_cast<Instruction>(ZExt)) {
      // For an argument the zext goes before the insertion point; for an
      // instruction source (where InsertPt is the source itself) it has to
      // come after it.
      if (isa<Argument>(V))
        I->moveBefore(InsertPt);
      else
        I->moveAfter(InsertPt);
      NewInsts.insert(I);
    }

    ReplaceAllUsersOfWith(V, ZExt);
  };

  // Now, insert extending instructions between the sources and their users.
  LLVM_DEBUG(dbgs() << "IR Promotion: Promoting sources:\n");
  for (auto *V : Sources) {
    LLVM_DEBUG(dbgs() << " - " << *V << "\n");
    if (auto *I = dyn_cast<Instruction>(V))
      InsertZExt(I, I);
    else if (auto *Arg = dyn_cast<Argument>(V)) {
      // Arguments are extended at the first insertion point of the first
      // block of their function.
      BasicBlock &BB = Arg->getParent()->front();
      InsertZExt(Arg, &*BB.getFirstInsertionPt());
    } else {
      llvm_unreachable("unhandled source that needs extending");
    }
    Promoted.insert(V);
  }
}

void IRPromoter::PromoteTree() {
  LLVM_DEBUG(dbgs() << "IR Promotion: Mutating the tree..\n");

  // Mutate the types of the instructions within the tree. Here we handle
  // constant operands.
  for (auto *V : Visited) {
    if (Sources.count(V))
      continue;

    auto *I = cast<Instruction>(V);
    if (Sinks.count(I))
      continue;

    for (unsigned i = 0, e = I->getNumOperands(); i < e; ++i) {
      Value *Op = I->getOperand(i);
      if ((Op->getType() == ExtTy) || !isa<IntegerType>(Op->getType()))
        continue;

      if (auto *Const = dyn_cast<ConstantInt>(Op)) {
        // For subtract, we don't need to sext the constant. We only put it in
        // SafeWrap because SafeWrap.size() is used elsewhere.
        // For cmp, we need to sign extend a constant appearing in either
        // operand. For add, we should only sign extend the RHS.
        Constant *NewConst = (SafeWrap.contains(I) &&
                              (I->getOpcode() == Instruction::ICmp || i == 1) &&
                              I->getOpcode() != Instruction::Sub)
                                 ? ConstantExpr::getSExt(Const, ExtTy)
                                 : ConstantExpr::getZExt(Const, ExtTy);
        I->setOperand(i, NewConst);
      } else if (isa<UndefValue>(Op))
        I->setOperand(i, ConstantInt::get(ExtTy, 0));
    }

    // Mutate the result type, unless this is an icmp or switch.
    if (!isa<ICmpInst>(I) && !isa<SwitchInst>(I)) {
      I->mutateType(ExtTy);
      Promoted.insert(I);
    }
  }
}

/// For every sink, insert a trunc back to the originally recorded type (see
/// TruncTysMap) in front of the sink for each operand that was promoted or
/// newly created, and make the sink use the trunc.
void IRPromoter::TruncateSinks() {
  LLVM_DEBUG(dbgs() << "IR Promotion: Fixing up the sinks:\n");

  IRBuilder<> Builder{Ctx};

  // Create a trunc of V to TruncTy, or return nullptr when V doesn't need
  // one: it isn't an integer-typed instruction, it was neither promoted nor
  // created by the promoter, or it is a source.
  auto InsertTrunc = [&](Value *V, Type *TruncTy) -> Instruction * {
    if (!isa<Instruction>(V) || !isa<IntegerType>(V->getType()))
      return nullptr;

    if ((!Promoted.count(V) && !NewInsts.count(V)) || Sources.count(V))
      return nullptr;

    LLVM_DEBUG(dbgs() << "IR Promotion: Creating " << *TruncTy << " Trunc for "
                      << *V << "\n");
    Builder.SetInsertPoint(cast<Instruction>(V));
    auto *Trunc = dyn_cast<Instruction>(Builder.CreateTrunc(V, TruncTy));
    if (Trunc)
      NewInsts.insert(Trunc);
    return Trunc;
  };

  // Fix up any stores or returns that use the results of the promoted
  // chain.
  for (auto *I : Sinks) {
    LLVM_DEBUG(dbgs() << "IR Promotion: For Sink: " << *I << "\n");

    // Handle calls separately as we need to iterate over arg operands.
    if (auto *Call = dyn_cast<CallInst>(I)) {
      for (unsigned i = 0; i < Call->arg_size(); ++i) {
        Value *Arg = Call->getArgOperand(i);
        Type *Ty = TruncTysMap[Call][i];
        if (Instruction *Trunc = InsertTrunc(Arg, Ty)) {
          // The trunc was created at V's position; it belongs right before
          // its user.
          Trunc->moveBefore(Call);
          Call->setArgOperand(i, Trunc);
        }
      }
      continue;
    }

    // Special case switches because we need to truncate the condition.
    if (auto *Switch = dyn_cast<SwitchInst>(I)) {
      Type *Ty = TruncTysMap[Switch][0];
      if (Instruction *Trunc = InsertTrunc(Switch->getCondition(), Ty)) {
        Trunc->moveBefore(Switch);
        Switch->setCondition(Trunc);
      }
      continue;
    }

    // Don't insert a trunc for a zext which can still legally promote.
    if (auto ZExt = dyn_cast<ZExtInst>(I))
      if (ZExt->getType()->getScalarSizeInBits() > PromotedWidth)
        continue;

    // Now handle the others.
    for (unsigned i = 0; i < I->getNumOperands(); ++i) {
      Type *Ty = TruncTysMap[I][i];
      if (Instruction *Trunc = InsertTrunc(I->getOperand(i), Ty)) {
        Trunc->moveBefore(I);
        I->setOperand(i, Trunc);
      }
    }
  }
}

void IRPromoter::Cleanup() {
  LLVM_DEBUG(dbgs() << "IR Promotion: Cleanup..\n");
  // Some zexts will now have become redundant, along with their trunc
  // operands, so remove them.
  for (auto *V : Visited) {
    if (!isa<ZExtInst>(V))
      continue;

    auto ZExt = cast<ZExtInst>(V);
    if (ZExt->getDestTy() != ExtTy)
      continue;

    Value *Src = ZExt->getOperand(0);
    if (ZExt->getSrcTy() == ZExt->getDestTy()) {
      LLVM_DEBUG(dbgs() << "IR Promotion: Removing unnecessary cast: " << *ZExt
                        << "\n");
      ReplaceAllUsersOfWith(ZExt, Src);
      continue;
    }

    // We've inserted a trunc for a zext sink, but we already know that the
    // input is in range, negating the need for the trunc.
    if (NewInsts.count(Src) && isa<TruncInst>(Src)) {
      auto *Trunc = cast<TruncInst>(Src);
      assert(Trunc->getOperand(0)->getType() == ExtTy &&
             "expected inserted trunc to be operating on i32");
      ReplaceAllUsersOfWith(ZExt, Trunc->getOperand(0));
    }
  }

  for (auto *I : InstsToRemove) {
    LLVM_DEBUG(dbgs() << "IR Promotion: Removing " << *I << "\n");
    I->dropAllReferences();
    I->eraseFromParent();
  }
}

void IRPromoter::ConvertTruncs() {
  LLVM_DEBUG(dbgs() << "IR Promotion: Converting truncs..\n");
  IRBuilder<> Builder{Ctx};

  for (auto *V : Visited) {
    if (!isa<TruncInst>(V) || Sources.count(V))
      continue;

    auto *Trunc = cast<TruncInst>(V);
    Builder.SetInsertPoint(Trunc);
    IntegerType *SrcTy = cast<IntegerType>(Trunc->getOperand(0)->getType());
    IntegerType *DestTy = cast<IntegerType>(TruncTysMap[Trunc][0]);

    unsigned NumBits = DestTy->getScalarSizeInBits();
    ConstantInt *Mask =
        ConstantInt::get(SrcTy, APInt::getMaxValue(NumBits).getZExtValue());
    Value *Masked = Builder.CreateAnd(Trunc->getOperand(0), Mask);
    if (SrcTy != ExtTy)
      Masked = Builder.CreateTrunc(Masked, ExtTy);

    if (auto *I = dyn_cast<Instruction>(Masked))
      NewInsts.insert(I);

    ReplaceAllUsersOfWith(Trunc, Masked);
  }
}

void IRPromoter::Mutate() {
  LLVM_DEBUG(dbgs() << "IR Promotion: Promoting use-def chains to "
                    << PromotedWidth << "-bits\n");

  // Cache original types of the values that will likely need truncating
  for (auto *I : Sinks) {
    if (auto *Call = dyn_cast<CallInst>(I)) {
      for (Value *Arg : Call->args())
        TruncTysMap[Call].push_back(Arg->getType());
    } else if (auto *Switch = dyn_cast<SwitchInst>(I))
      TruncTysMap[I].push_back(Switch->getCondition()->getType());
    else {
      for (unsigned i = 0; i < I->getNumOperands(); ++i)
        TruncTysMap[I].push_back(I->getOperand(i)->getType());
    }
  }
  for (auto *V : Visited) {
    if (!isa<TruncInst>(V) || Sources.count(V))
      continue;
    auto *Trunc = cast<TruncInst>(V);
    TruncTysMap[Trunc].push_back(Trunc->getDestTy());
  }

  // Insert zext instructions between sources and their users.
  ExtendSources();

  // Promote visited instructions, mutating their types in place.
  PromoteTree();

  // Convert any truncs, that aren't sources, into AND masks.
  ConvertTruncs();

  // Insert trunc instructions for use by calls, stores etc...
  TruncateSinks();

  // Finally, remove unecessary zexts and truncs, delete old instructions and
  // clear the data structures.
  Cleanup();

  LLVM_DEBUG(dbgs() << "IR Promotion: Mutation complete\n");
}

/// Return true if the type of \p V can be handled. Void and pointer types are
/// accepted because they won't be promoted. Otherwise the type has to be an
/// integer that isn't a boolean (i1), isn't wider than RegisterBitWidth and
/// isn't wider than TypeSize.
bool TypePromotion::isSupportedType(Value *V) {
  Type *Ty = V->getType();

  // Allow voids and pointers, these won't be promoted.
  if (Ty->isVoidTy() || Ty->isPointerTy())
    return true;

  auto *IntTy = dyn_cast<IntegerType>(Ty);
  if (!IntTy)
    return false;

  const unsigned Width = IntTy->getBitWidth();
  if (Width == 1 || Width > RegisterBitWidth)
    return false;

  return LessOrEqualTypeSize(V);
}

/// Return true if \p V is a value the promotion can deal with. Arguments and
/// constants (other than constant expressions) are accepted if their type is
/// supported, and basic blocks are always accepted. For instructions the
/// decision depends on the opcode: casts other than zext, trunc and bitcast
/// are rejected, calls need a zeroext return value and binary operators must
/// not introduce sign bits.
bool TypePromotion::isSupportedValue(Value *V) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I) {
    const bool IsPlainConstant = isa<Constant>(V) && !isa<ConstantExpr>(V);
    if (IsPlainConstant || isa<Argument>(V))
      return isSupportedType(V);
    return isa<BasicBlock>(V);
  }

  switch (I->getOpcode()) {
  case Instruction::GetElementPtr:
  case Instruction::Store:
  case Instruction::Br:
  case Instruction::Switch:
    return true;

  case Instruction::PHI:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
  case Instruction::Trunc:
  case Instruction::BitCast:
    return isSupportedType(I);

  case Instruction::ZExt:
    return isSupportedType(I->getOperand(0));

  case Instruction::ICmp: {
    // Now that we allow small types than TypeSize, only allow icmp of
    // TypeSize because they will require a trunc to be legalised.
    // TODO: Allow icmp of smaller types, and calculate at the end
    // whether the transform would be beneficial.
    Value *LHS = I->getOperand(0);
    return isa<PointerType>(LHS->getType()) || EqualTypeSize(LHS);
  }

  case Instruction::Call:
    // Special cases for calls as we need to check for zeroext
    // TODO We should accept calls even if they don't have zeroext, as they
    // can still be sinks.
    return isSupportedType(I) &&
           cast<CallInst>(I)->hasRetAttr(Attribute::AttrKind::ZExt);

  default:
    return isa<BinaryOperator>(I) && isSupportedType(I) &&
           !GenerateSignBits(I);
  }
}

/// Return true if \p V can be promoted. Non-instructions are always legal. An
/// instruction is legal if its promoted result is trivially safe or if it is
/// a safely wrapping add/sub; accepted instructions are cached in
/// SafeToPromote.
bool TypePromotion::isLegalToPromote(Value *V) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I || SafeToPromote.count(I))
    return true;

  if (!isPromotedResultSafe(I) && !isSafeWrap(I))
    return false;

  SafeToPromote.insert(I);
  return true;
}

/// Try to build a promotion tree starting from \p V and, if every value in
/// it can be handled and the promotion looks worthwhile, promote the tree to
/// \p PromotedWidth bits with IRPromoter.
///
/// TypeSize is set to the bit size of V's type, and SafeToPromote and
/// SafeWrap are reset on entry.
///
/// \returns true if the IR was mutated, false otherwise.
bool TypePromotion::TryToPromote(Value *V, unsigned PromotedWidth) {
  Type *OrigTy = V->getType();
  TypeSize = OrigTy->getPrimitiveSizeInBits().getFixedSize();
  SafeToPromote.clear();
  SafeWrap.clear();

  if (!isSupportedValue(V) || !shouldPromote(V) || !isLegalToPromote(V))
    return false;

  LLVM_DEBUG(dbgs() << "IR Promotion: TryToPromote: " << *V << ", from "
                    << TypeSize << " bits to " << PromotedWidth << "\n");

  SetVector<Value *> WorkList;
  SetVector<Value *> Sources;
  SetVector<Instruction *> Sinks;
  SetVector<Value *> CurrentVisited;
  WorkList.insert(V);

  // Return true if V was added to the worklist as a supported instruction,
  // if it was already visited, or if we don't need to explore it (e.g.
  // pointer values and GEPs), and false otherwise.
  auto AddLegalInst = [&](Value *V) {
    if (CurrentVisited.count(V))
      return true;

    // Ignore GEPs because they don't need promoting and the constant indices
    // will prevent the transformation.
    if (isa<GetElementPtrInst>(V))
      return true;

    if (!isSupportedValue(V) || (shouldPromote(V) && !isLegalToPromote(V))) {
      LLVM_DEBUG(dbgs() << "IR Promotion: Can't handle: " << *V << "\n");
      return false;
    }

    WorkList.insert(V);
    return true;
  };

  // Iterate through, and add to, a tree of operands and users in the use-def.
  while (!WorkList.empty()) {
    Value *V = WorkList.pop_back_val();
    if (CurrentVisited.count(V))
      continue;

    // Ignore non-instructions, other than arguments.
    if (!isa<Instruction>(V) && !isSource(V))
      continue;

    // If we've already visited this value from somewhere, bail now because
    // the tree has already been explored.
    // TODO: This could limit the transform, ie if we try to promote something
    // from an i8 and fail first, before trying an i16.
    if (AllVisited.count(V))
      return false;

    CurrentVisited.insert(V);
    AllVisited.insert(V);

    // Calls can be both sources and sinks.
    if (isSink(V))
      Sinks.insert(cast<Instruction>(V));

    if (isSource(V))
      Sources.insert(V);

    // Values in the middle of the tree have their operands explored too.
    if (!isSink(V) && !isSource(V)) {
      if (auto *I = dyn_cast<Instruction>(V)) {
        // Visit operands of any instruction visited.
        for (auto &U : I->operands()) {
          if (!AddLegalInst(U))
            return false;
        }
      }
    }

    // Don't visit users of a node which isn't going to be mutated unless its a
    // source.
    if (isSource(V) || shouldPromote(V)) {
      for (Use &U : V->uses()) {
        if (!AddLegalInst(U.getUser()))
          return false;
      }
    }
  }

  LLVM_DEBUG({
    dbgs() << "IR Promotion: Visited nodes:\n";
    for (auto *I : CurrentVisited)
      I->dump();
  });

  // Gather the figures used by the profitability check below: the number of
  // values that would be mutated (neither source nor sink), the number of
  // argument sources without a zeroext/signext attribute, and the set of
  // blocks the tree spans.
  unsigned ToPromote = 0;
  unsigned NonFreeArgs = 0;
  SmallPtrSet<BasicBlock *, 4> Blocks;
  for (auto *V : CurrentVisited) {
    if (auto *I = dyn_cast<Instruction>(V))
      Blocks.insert(I->getParent());

    if (Sources.count(V)) {
      if (auto *Arg = dyn_cast<Argument>(V))
        if (!Arg->hasZExtAttr() && !Arg->hasSExtAttr())
          ++NonFreeArgs;
      continue;
    }

    // Visited values that aren't sources are instructions: the worklist loop
    // above skips every non-instruction that isn't a source.
    if (Sinks.count(cast<Instruction>(V)))
      continue;
    ++ToPromote;
  }

  // DAG optimizations should be able to handle these cases better, especially
  // for function arguments.
  if (ToPromote < 2 || (Blocks.size() == 1 && (NonFreeArgs > SafeWrap.size())))
    return false;

  IRPromoter Promoter(*Ctx, PromotedWidth, CurrentVisited, Sources, Sinks,
                      SafeWrap);
  Promoter.Mutate();
  return true;
}

/// Pass entry point. Walks every unsigned integer icmp in \p F and tries to
/// promote the tree hanging off its first instruction operand, provided that
/// operand's type is one the target legalises by integer promotion.
///
/// \returns true if any promotion was performed.
bool TypePromotion::runOnFunction(Function &F) {
  if (skipFunction(F) || DisablePromotion)
    return false;

  LLVM_DEBUG(dbgs() << "IR Promotion: Running on " << F.getName() << "\n");

  // Without a TargetPassConfig there is no TargetMachine to query.
  auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
  if (!TPC)
    return false;

  AllVisited.clear();
  SafeToPromote.clear();
  SafeWrap.clear();
  bool MadeChange = false;
  const DataLayout &DL = F.getParent()->getDataLayout();
  const TargetMachine &TM = TPC->getTM<TargetMachine>();
  const TargetSubtargetInfo *SubtargetInfo = TM.getSubtargetImpl(F);
  const TargetLowering *TLI = SubtargetInfo->getTargetLowering();
  // Note: despite its name, TII holds the TargetTransformInfo.
  const TargetTransformInfo &TII =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  RegisterBitWidth =
      TII.getRegisterBitWidth(TargetTransformInfo::RGK_Scalar).getFixedSize();
  Ctx = &F.getParent()->getContext();

  // Search up from icmps to try to promote their operands.
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      if (AllVisited.count(&I))
        continue;

      if (!isa<ICmpInst>(&I))
        continue;

      auto *ICmp = cast<ICmpInst>(&I);
      // Skip signed or pointer compares
      if (ICmp->isSigned() || !isa<IntegerType>(ICmp->getOperand(0)->getType()))
        continue;

      LLVM_DEBUG(dbgs() << "IR Promotion: Searching from: " << *ICmp << "\n");

      // Only the first operand that is an instruction is considered: every
      // path below ends in a break.
      for (auto &Op : ICmp->operands()) {
        if (auto *I = dyn_cast<Instruction>(Op)) {
          // Nothing to do if the type is already legal for the target.
          EVT SrcVT = TLI->getValueType(DL, I->getType());
          if (SrcVT.isSimple() && TLI->isTypeLegal(SrcVT.getSimpleVT()))
            break;

          // Only handle types that the target legalises by promotion.
          if (TLI->getTypeAction(*Ctx, SrcVT) !=
              TargetLowering::TypePromoteInteger)
            break;
          EVT PromotedVT = TLI->getTypeToTransformTo(*Ctx, SrcVT);
          if (RegisterBitWidth < PromotedVT.getFixedSizeInBits()) {
            LLVM_DEBUG(dbgs() << "IR Promotion: Couldn't find target register "
                              << "for promoted type\n");
            break;
          }

          MadeChange |= TryToPromote(I, PromotedVT.getFixedSizeInBits());
          break;
        }
      }
    }
  }

  // Drop the per-function state.
  AllVisited.clear();
  SafeToPromote.clear();
  SafeWrap.clear();

  return MadeChange;
}

// Register the pass. The analyses requested in getAnalysisUsage are declared
// as dependencies so that they are initialized along with this pass.
INITIALIZE_PASS_BEGIN(TypePromotion, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(TypePromotion, DEBUG_TYPE, PASS_NAME, false, false)

// Definition of the static ID that the constructor hands to FunctionPass.
char TypePromotion::ID = 0;

/// Create a new instance of the TypePromotion pass; the caller receives the
/// raw pointer returned by new.
FunctionPass *llvm::createTypePromotionPass() {
  return new TypePromotion();
}