#include "VECustomDAG.h"
#include "VEISelLowering.h"
using namespace llvm;
#define DEBUG_TYPE "ve-lower"
/// Split a two-operand mask operation into two v256i1 operations.
///
/// Both operands of \p Op are unpacked into their Lo and Hi parts, the opcode
/// of \p Op is applied to each pair of parts, and the two results are packed
/// into a v512i1 value.
///
/// \param Op  node with (at least) two operands; its opcode is reused for the
///            per-part nodes.
/// \param DAG selection DAG the new nodes are created in.
/// \returns the pack node combining the Lo and Hi results.
SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
                                              SelectionDAG &DAG) const {
  VECustomDAG CDAG(DAG, Op);

  // The element count of the result type is used as AVL for every unpack and
  // for the final pack.
  const unsigned NumElems = Op.getValueType().getVectorNumElements();
  SDValue AVL = CDAG.getConstant(NumElems, MVT::i32);

  // Extract one v256i1 part of a packed operand.
  auto UnpackPart = [&](SDValue Packed, PackElem Part) -> SDValue {
    return CDAG.getUnpack(MVT::v256i1, Packed, Part, AVL);
  };

  SDValue Lhs = Op->getOperand(0);
  SDValue Rhs = Op->getOperand(1);
  SDValue LhsLo = UnpackPart(Lhs, PackElem::Lo);
  SDValue LhsHi = UnpackPart(Lhs, PackElem::Hi);
  SDValue RhsLo = UnpackPart(Rhs, PackElem::Lo);
  SDValue RhsHi = UnpackPart(Rhs, PackElem::Hi);

  // Apply the original opcode to each part separately.
  const unsigned Opc = Op.getOpcode();
  SDValue ResLo = CDAG.getNode(Opc, MVT::v256i1, {LhsLo, RhsLo});
  SDValue ResHi = CDAG.getNode(Opc, MVT::v256i1, {LhsHi, RhsHi});

  return CDAG.getPack(MVT::v512i1, ResLo, ResHi, AVL);
}
/// Lower \p Op to the VVP node that getVVPOpcode maps its opcode to.
///
/// Load/store and gather/scatter nodes are handed to their dedicated lowering
/// routines. All other nodes are rebuilt with a trailing (mask, AVL) operand
/// pair: both are read from \p Op when it is a VP node that provides them, and
/// otherwise default to an all-true mask and the element count of the
/// idiomatic vector type.
///
/// \param Op  node to lower.
/// \param DAG selection DAG used to create the replacement nodes.
/// \returns the replacement node, or an empty SDValue if the opcode of \p Op
///          has no VVP counterpart.
SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
  const unsigned SrcOpcode = Op->getOpcode();
  auto MappedOpcode = getVVPOpcode(SrcOpcode);
  if (!MappedOpcode)
    return SDValue();
  const unsigned VVPOpcode = *MappedOpcode;

  VECustomDAG CDAG(DAG, Op);

  // Memory operations have their own lowering functions.
  if (VVPOpcode == VEISD::VVP_LOAD || VVPOpcode == VEISD::VVP_STORE)
    return lowerVVP_LOAD_STORE(Op, CDAG);
  if (VVPOpcode == VEISD::VVP_GATHER || VVPOpcode == VEISD::VVP_SCATTER)
    return lowerVVP_GATHER_SCATTER(Op, CDAG);

  // Idiomatic vector type of the operation and the type it is transformed to.
  EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
  EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
  auto MaskPacking = getTypePacking(LegalVecVT.getSimpleVT());

  // Pick up mask and AVL from a VP node if it has them.
  SDValue Mask;
  SDValue AVL;
  if (ISD::isVPOpcode(SrcOpcode)) {
    if (auto MaskIdx = ISD::getVPMaskIdx(SrcOpcode))
      Mask = Op->getOperand(*MaskIdx);
    if (auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(SrcOpcode))
      AVL = Op->getOperand(*AVLIdx);
  }

  // Fall back to the full element count and the all-true mask.
  if (!AVL)
    AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
  if (!Mask)
    Mask = CDAG.getConstantMask(MaskPacking, true);

  assert(LegalVecVT.isSimple());

  if (isVVPUnaryOp(VVPOpcode)) {
    SDValue Src = Op->getOperand(0);
    return CDAG.getNode(VVPOpcode, LegalVecVT, {Src, Mask, AVL});
  }

  if (isVVPBinaryOp(VVPOpcode)) {
    SDValue Lhs = Op->getOperand(0);
    SDValue Rhs = Op->getOperand(1);
    return CDAG.getNode(VVPOpcode, LegalVecVT, {Lhs, Rhs, Mask, AVL});
  }

  if (isVVPReductionOp(VVPOpcode)) {
    // When the source node has a start parameter it is operand 0 and the
    // vector is operand 1; otherwise the vector is operand 0 and the start
    // value stays empty.
    const bool HasStart = hasReductionStartParam(SrcOpcode);
    SDValue StartV;
    if (HasStart)
      StartV = Op->getOperand(0);
    SDValue VectorV = Op->getOperand(HasStart ? 1 : 0);
    return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
                                       VectorV, Mask, AVL, Op->getFlags());
  }

  switch (VVPOpcode) {
  case VEISD::VVP_SETCC: {
    // The result type is derived from the value type of Op itself, not from
    // the idiomatic vector type.
    EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
    SDValue Lhs = Op->getOperand(0);
    SDValue Rhs = Op->getOperand(1);
    SDValue Pred = Op->getOperand(2);
    return CDAG.getNode(VVPOpcode, LegalResVT, {Lhs, Rhs, Pred, Mask, AVL});
  }
  case VEISD::VVP_SELECT: {
    // Operand 0 of the source node is used as the mask operand here, in place
    // of the mask determined above.
    SDValue SelectMask = Op->getOperand(0);
    SDValue OnTrue = Op->getOperand(1);
    SDValue OnFalse = Op->getOperand(2);
    return CDAG.getNode(VVPOpcode, LegalVecVT,
                        {OnTrue, OnFalse, SelectMask, AVL});
  }
  case VEISD::VVP_FFMA: {
    // The source operands are passed on in the order (2, 0, 1).
    SDValue First = Op->getOperand(2);
    SDValue Second = Op->getOperand(0);
    SDValue Third = Op->getOperand(1);
    return CDAG.getNode(VVPOpcode, LegalVecVT,
                        {First, Second, Third, Mask, AVL});
  }
  default:
    llvm_unreachable("lowerToVVP called for unexpected SDNode.");
  }
}
/// Lower a load or store node to VVP_LOAD / VVP_STORE.
///
/// A missing AVL defaults to the element count of the idiomatic data type and
/// a missing mask to the all-true mask. For a load whose pass-through value is
/// present and not undef, the loaded value is wrapped in a VVP_SELECT and
/// returned together with the chain result of the load.
///
/// \param Op   node whose opcode maps to VVP_LOAD or VVP_STORE.
/// \param CDAG DAG builder used to create the new nodes.
/// \returns the VVP_LOAD node, a merge of (selected data, load chain), or the
///          VVP_STORE node.
SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
                                              VECustomDAG &CDAG) const {
  auto VVPOpc = *getVVPOpcode(Op->getOpcode());

  // Operands common to loads and stores.
  SDValue Ptr = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue AVL = getNodeAVL(Op);
  // Value operand of a store and pass-through operand of a load.
  SDValue StoreData = getStoredValue(Op);
  SDValue PassThru = getNodePassthru(Op);

  SDValue Stride = getLoadStoreStride(Op, CDAG);

  auto DataVT = *getIdiomaticVectorType(Op.getNode());
  auto DataPacking = getTypePacking(DataVT);

  // Materialize defaults for absent predication operands.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
  if (!Mask)
    Mask = CDAG.getConstantMask(DataPacking, true);

  if (VVPOpc != VEISD::VVP_LOAD) {
    assert(VVPOpc == VEISD::VVP_STORE);
    return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
                        {Chain, StoreData, Ptr, Stride, Mask, AVL});
  }

  MVT LegalDataVT = getLegalVectorType(
      DataPacking, DataVT.getVectorElementType().getSimpleVT());
  SDValue LoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
                               {Chain, Ptr, Stride, Mask, AVL});

  // Without a meaningful pass-through the plain load is the result.
  const bool NeedsSelect = PassThru && !PassThru->isUndef();
  if (!NeedsSelect)
    return LoadV;

  // Blend the loaded value with the pass-through value.
  // NOTE(review): the select is typed with DataVT whereas the load above (and
  // the select in the gather lowering) use the legal data type - confirm that
  // this difference is intended.
  SDValue SelectedV =
      CDAG.getNode(VEISD::VVP_SELECT, DataVT, {LoadV, PassThru, Mask, AVL});
  SDValue LoadChain = SDValue(LoadV.getNode(), 1);
  return CDAG.getMergeValues({SelectedV, LoadChain});
}
/// Split a packed VVP load or store into one operation per part (Hi and Lo).
///
/// For each part a split mask/AVL, a split pointer offset and a split stride
/// are requested from \p CDAG and a new node with the same VVP opcode is
/// emitted. The chain results of both part nodes are joined with a
/// TokenFactor.
///
/// \param Op   node whose opcode maps to VVP_LOAD or VVP_STORE; it must not
///             have a pass-through operand and its data type must have dense
///             packing (both asserted).
/// \param CDAG DAG builder used to create the new nodes.
/// \returns for a store (a stored value is present) the fused chain; for a
///          load a merge of (re-packed value, fused chain).
SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
                                               VECustomDAG &CDAG) const {
  auto VVPOC = *getVVPOpcode(Op.getOpcode());
  assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));

  MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  assert(getTypePacking(DataVT) == Packing::Dense &&
         "Can only split packed load/store");
  MVT PartDataVT = splitVectorType(DataVT);

  assert(!getNodePassthru(Op) &&
         "Should have been folded in lowering to VVP layer");

  // Operands of the packed operation.
  SDValue WholeMask = getNodeMask(Op);
  SDValue WholeAVL = getAnnotatedNodeAVL(Op).first;
  SDValue WholePtr = getMemoryPtr(Op);
  SDValue WholeData = getStoredValue(Op);
  SDValue WholeStride = getLoadStoreStride(Op, CDAG);

  // A stored value marks a store. The chain is result 0 of a store part and
  // result 1 of a load part.
  const bool IsStore = static_cast<bool>(WholeData);
  const unsigned ChainResIdx = IsStore ? 0 : 1;

  SDValue PartNodes[2];
  SDValue HiPartAVL; // AVL of the Hi part, used for the final pack.

  // The Hi part is emitted before the Lo part.
  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    auto PartTM = CDAG.getTargetSplitMask(WholeMask, WholeAVL, Part);
    if (Part == PackElem::Hi)
      HiPartAVL = PartTM.AVL;

    // Operand order: chain, [data], pointer, stride, mask, AVL.
    SmallVector<SDValue, 4> PartOperands;
    PartOperands.push_back(getNodeChain(Op));
    if (IsStore)
      PartOperands.push_back(
          CDAG.getUnpack(PartDataVT, WholeData, Part, PartTM.AVL));
    PartOperands.push_back(CDAG.getSplitPtrOffset(WholePtr, WholeStride, Part));
    PartOperands.push_back(CDAG.getSplitPtrStride(WholeStride));
    PartOperands.push_back(PartTM.Mask);
    PartOperands.push_back(PartTM.AVL);

    SDValue &PartNode = PartNodes[static_cast<int>(Part)];
    if (IsStore)
      PartNode = CDAG.getNode(VVPOC, MVT::Other, PartOperands);
    else
      PartNode = CDAG.getNode(VVPOC, {PartDataVT, MVT::Other}, PartOperands);
  }

  SDValue LoNode = PartNodes[static_cast<int>(PackElem::Lo)];
  SDValue HiNode = PartNodes[static_cast<int>(PackElem::Hi)];

  // Join the chains of both parts.
  SDValue LoChain = SDValue(LoNode.getNode(), ChainResIdx);
  SDValue HiChain = SDValue(HiNode.getNode(), ChainResIdx);
  SDValue JoinedChain =
      CDAG.getNode(ISD::TokenFactor, MVT::Other, {LoChain, HiChain});

  // A store only produces a chain.
  if (IsStore)
    return JoinedChain;

  // A load additionally re-packs the two loaded parts.
  MVT PackedVT =
      getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
  SDValue PackedV = CDAG.getPack(PackedVT, LoNode, HiNode, HiPartAVL);
  return CDAG.getMergeValues({PackedV, JoinedChain});
}
/// Lower a gather or scatter node to VVP_GATHER / VVP_SCATTER.
///
/// The address vector is computed from base pointer, scale and index via
/// \p CDAG. A missing AVL defaults to the element count of the idiomatic data
/// type and a missing mask to the all-true mask. A node with a stored value is
/// treated as a scatter. A gather with a pass-through value that is present
/// and not undef is followed by a VVP_SELECT.
///
/// \param Op   gather/scatter node to lower.
/// \param CDAG DAG builder used to create the new nodes.
/// \returns the VVP_SCATTER node, the VVP_GATHER node, or a merge of
///          (selected data, gather chain).
SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
                                                  VECustomDAG &CDAG) const {
  EVT DataVT = *getIdiomaticVectorType(Op.getNode());
  auto DataPacking = getTypePacking(DataVT);
  MVT LegalDataVT = getLegalVectorType(
      DataPacking, DataVT.getVectorElementType().getSimpleVT());

  SDValue AVL = getAnnotatedNodeAVL(Op).first;
  SDValue Index = getGatherScatterIndex(Op);
  SDValue Base = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue Scale = getGatherScatterScale(Op);
  SDValue PassThru = getNodePassthru(Op);
  SDValue ScatterData = getStoredValue(Op);

  // An undef pass-through is handled like an absent one.
  if (PassThru && PassThru->isUndef())
    PassThru = SDValue();

  // Materialize defaults for absent predication operands.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
  if (!Mask)
    Mask = CDAG.getConstantMask(DataPacking, true);

  SDValue Addresses =
      CDAG.getGatherScatterAddress(Base, Scale, Index, Mask, AVL);

  // A stored value marks a scatter.
  if (static_cast<bool>(ScatterData))
    return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
                        {Chain, ScatterData, Addresses, Mask, AVL});

  SDValue GatherV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
                                 {Chain, Addresses, Mask, AVL});
  if (!PassThru)
    return GatherV;

  // Blend the gathered value with the pass-through value and hand back the
  // chain result of the gather alongside it.
  SDValue SelectedV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
                                   {GatherV, PassThru, Mask, AVL});
  SDValue GatherChain = SDValue(GatherV.getNode(), 1);
  return CDAG.getMergeValues({SelectedV, GatherChain});
}
/// Legalize a load/store node: split it when its idiomatic data type is a
/// packed vector type, otherwise only legalize its AVL.
///
/// \param Op   load/store node to legalize.
/// \param CDAG DAG builder forwarded to the chosen routine.
/// \returns the result of splitPackedLoadStore or legalizePackedAVL.
SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
                                                      VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
  const MVT IdiomDataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  return isPackedVectorType(IdiomDataVT) ? splitPackedLoadStore(Op, CDAG)
                                         : legalizePackedAVL(Op, CDAG);
}
/// Legalize an internal vector node.
///
/// VVP_LOAD and VVP_STORE go to legalizeInternalLoadStoreOp. Any other node is
/// split when its value type is a packed vector type for which
/// supportsPackedMode reports no support; otherwise only its AVL is legalized.
///
/// \param Op  node to legalize.
/// \param DAG selection DAG used to create replacement nodes.
/// \returns the result of the selected legalization routine.
SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
                                                   SelectionDAG &DAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
  VECustomDAG CDAG(DAG, Op);

  const unsigned Opc = Op->getOpcode();
  if (Opc == VEISD::VVP_LOAD || Opc == VEISD::VVP_STORE)
    return legalizeInternalLoadStoreOp(Op, CDAG);

  EVT ValueVT = Op.getValueType();
  const bool MustSplit =
      isPackedVectorType(ValueVT) && !supportsPackedMode(Opc, ValueVT);
  if (MustSplit)
    return splitVectorOp(Op, CDAG);

  return legalizePackedAVL(Op, CDAG);
}
/// Split a packed vector node into one node per part (Hi and Lo) and pack the
/// two results again.
///
/// Every operand except the mask and AVL operands is unpacked for the
/// respective part; the split mask and AVL obtained from \p CDAG are appended
/// as the last two operands of each part node. The node flags of \p Op are
/// carried over to the part nodes.
///
/// \param Op   node to split; its AVL must not be marked as already legalized
///             (asserted).
/// \param CDAG DAG builder used to create the new nodes.
/// \returns the pack node of the Lo and Hi results, typed like \p Op.
SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
  MVT PartResVT = splitVectorType(Op.getValue(0).getSimpleValueType());

  auto AVLPos = getAVLPos(Op->getOpcode());
  auto MaskPos = getMaskPos(Op->getOpcode());

  SDValue WholeMask = getNodeMask(Op);
  auto AnnotatedAVL = getAnnotatedNodeAVL(Op);
  SDValue WholeAVL = AnnotatedAVL.first;
  assert(!AnnotatedAVL.second && "Expecting non pack-legalized oepration");

  SDValue PartNodes[2];
  SDValue HiPartAVL; // AVL of the Hi part, used for the final pack.

  const int NumOperands = Op.getNumOperands();

  // The Hi part is emitted before the Lo part.
  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    auto PartTM = CDAG.getTargetSplitMask(WholeMask, WholeAVL, Part);
    if (Part == PackElem::Hi)
      HiPartAVL = PartTM.AVL;

    SmallVector<SDValue, 4> PartOperands;
    for (int Idx = 0; Idx < NumOperands; ++Idx) {
      // The predicating operands are replaced by the split ones below.
      const bool IsAVLOperand = AVLPos && Idx == *AVLPos;
      const bool IsMaskOperand = MaskPos && Idx == *MaskPos;
      if (IsAVLOperand || IsMaskOperand)
        continue;

      // Unpack the matching part of a value operand.
      SDValue WholeOperand = Op.getOperand(Idx);
      auto PartOperandVT = splitVectorType(WholeOperand.getSimpleValueType());
      PartOperands.push_back(
          CDAG.getUnpack(PartOperandVT, WholeOperand, Part, PartTM.AVL));
    }

    // Mask and AVL always come last.
    PartOperands.push_back(PartTM.Mask);
    PartOperands.push_back(PartTM.AVL);

    PartNodes[static_cast<int>(Part)] =
        CDAG.getNode(Op.getOpcode(), PartResVT, PartOperands, Op->getFlags());
  }

  SDValue LoNode = PartNodes[static_cast<int>(PackElem::Lo)];
  SDValue HiNode = PartNodes[static_cast<int>(PackElem::Hi)];
  return CDAG.getPack(Op.getValueType(), LoNode, HiNode, HiPartAVL);
}
/// Replace the AVL operand of \p Op with an annotated legal AVL.
///
/// Nodes for which isVVPOrVEC is false, and nodes whose AVL is already legal,
/// are returned unchanged. When the idiomatic vector type is a packed vector
/// type the AVL is halved, rounding up, before it is annotated. The node is
/// then re-created with the same opcode, value types and flags, and with all
/// other operands unchanged.
///
/// \param Op   node to legalize.
/// \param CDAG DAG builder used to create the new nodes.
/// \returns \p Op itself or its clone with the legalized AVL operand.
SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
                                            VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
  if (!isVVPOrVEC(Op->getOpcode()))
    return Op;

  auto OldAVL = getNodeAVL(Op);
  if (isLegalAVL(OldAVL))
    return Op;

  SDValue NewAVL = OldAVL;
  MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  if (isPackedVectorType(IdiomVT)) {
    assert(maySafelyIgnoreMask(Op) &&
           "TODO Shift predication from EVL into Mask");

    // NewAVL = (OldAVL + 1) / 2, folded right away for a constant AVL and
    // emitted as add + logical shift right otherwise.
    auto *ConstAVL = dyn_cast<ConstantSDNode>(OldAVL);
    if (ConstAVL) {
      NewAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
    } else {
      auto One = CDAG.getConstant(1, MVT::i32);
      auto Incremented = CDAG.getNode(ISD::ADD, MVT::i32, {OldAVL, One});
      NewAVL = CDAG.getNode(ISD::SRL, MVT::i32, {Incremented, One});
    }
  }

  SDValue AnnotatedAVL = CDAG.annotateLegalAVL(NewAVL);

  // Copy the operand list, substituting the AVL slot.
  const int NumOperands = Op->getNumOperands();
  auto AVLPos = getAVLPos(Op->getOpcode());
  std::vector<SDValue> NewOperands;
  for (int Idx = 0; Idx < NumOperands; ++Idx) {
    const bool IsAVLSlot = AVLPos && (Idx == *AVLPos);
    if (IsAVLSlot)
      NewOperands.push_back(AnnotatedAVL);
    else
      NewOperands.push_back(Op->getOperand(Idx));
  }

  // Re-create the node with the adjusted operands.
  return CDAG.getNode(Op->getOpcode(), Op->getVTList(), NewOperands,
                      Op->getFlags());
}