diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index e1e356ed8..6d9ddc1c0 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -545,7 +545,7 @@ class SchedulerBase : public Scheduler2D { mBlock[1] = mThdSize[1]; } bsize = KRef * mBlock[1] * mEleSize[1] * 2; - size_t csize = mBlock[0] * mBlock[1] * mEleSize[2]; + size_t csize = static_cast(mBlock[0]) * mBlock[1] * mEleSize[2]; auto rawk = static_cast((valid_total - csize) / (mStep[0] * mEleSize[0] + mBlock[1] * mEleSize[1]) / 2); rawk = std::min(rawk, mSizePadded[2]); mBlock[2] = utils::padto_le(rawk, mStep[2]); @@ -634,14 +634,14 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { } else { this->mBlock[1] = this->mThdSize[1]; } - auto rawk = static_cast((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2]) / 2 / + size_t csize = static_cast(mBlock[0]) * mBlock[1] * this->mEleSize[2]; + auto rawk = static_cast((valid_total - csize) / 2 / (this->mStep[0] * this->mEleSize[0] + float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock + this->mBlock[1] * this->mEleSize[1])); if (rawk < this->mKBlock) { - rawk = static_cast((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] - - 1 * CorSize * (this->mStep[0] + this->mBlock[1])) / - 2 / (this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1])); + rawk = static_cast((valid_total - csize - 1 * CorSize * (this->mStep[0] + this->mBlock[1])) / 2 / + (this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1])); } rawk = std::min(rawk, this->mSizePadded[2]); this->mBlock[2] = utils::padto_le(rawk, this->mStep[2]);