Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

32 minute compile time for max_pool2d_with_indices #8429

Open
jansel opened this issue Sep 28, 2024 · 3 comments
Open

32 minute compile time for max_pool2d_with_indices #8429

jansel opened this issue Sep 28, 2024 · 3 comments
Assignees
Labels
autoscheduler Related to one or more of the Autoschedulers

Comments

@jansel
Copy link
Contributor

jansel commented Sep 28, 2024

This example takes 32 minutes to compile, while typical kernels take seconds (not minutes). I suspect it is hitting some sort of pathological case in Halide.

repro.py

import halide as hl
from torch._inductor.runtime import halide_helpers
from math import inf, nan


@hl.generator(name="kernel")
class Kernel:
    # Halide generator reproducing a slow Anderson2021 autoscheduling run.
    # The pipeline reads nine values from `in_ptr0` per output point, keeps a
    # running NaN-propagating maximum (written to `out_ptr0`) and tracks which
    # of the nine candidates won, turning that into an int64 index
    # (written to `out_ptr2`).

    # 5-D float32 input; its min/stride/extent per dimension are pinned at the
    # end of generate().
    in_ptr0 = hl.InputBuffer(hl.Float(32), 5)
    # 4-D float32 output holding the running maximum.
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 4)
    # 4-D int64 output holding the index derived from the winning candidate.
    out_ptr2 = hl.OutputBuffer(hl.Int(64), 4)

    def generate(g):
        """Define the pipeline and give the autoscheduler its constraints.

        Every intermediate value is wrapped in its own ``hl.Func`` (tmp0 ..
        tmp52), so the autoscheduler sees one schedulable stage per
        operation.
        """
        in_ptr0 = g.in_ptr0
        out_ptr0 = g.out_ptr0
        out_ptr2 = g.out_ptr2
        h0 = hl.Var("h0")
        h1 = hl.Var("h1")
        h2 = hl.Var("h2")
        h3 = hl.Var("h3")

        # --- Value chain -------------------------------------------------
        # Odd-numbered Funcs below load one candidate from in_ptr0; the
        # even-numbered Func that follows merges it into the running maximum.
        # The `if <func>.type().is_float()` test is ordinary Python evaluated
        # while the pipeline is being built: float types take the select()
        # form (which also picks the new candidate when it is NaN), other
        # types take hl.max().
        # NOTE(review): the loads using 13 + h0 and 14 + h0 index dim(1) past
        # the extent of 13 that is set on it below; presumably this relies on
        # the explicit strides to land in the following row -- confirm.
        tmp0 = hl.Func("tmp0")
        tmp0[h0, h1, h2, h3] = in_ptr0[0, h0, h1, h2, h3]
        tmp1 = hl.Func("tmp1")
        tmp1[h0, h1, h2, h3] = in_ptr0[1, h0, h1, h2, h3]
        # Running max after candidates 0 and 1.
        tmp2 = hl.Func("tmp2")
        tmp2[h0, h1, h2, h3] = (
            hl.select(
                (tmp1[h0, h1, h2, h3] > hl.cast(tmp1.type(), tmp0[h0, h1, h2, h3]))
                | hl.is_nan(tmp1[h0, h1, h2, h3]),
                tmp1[h0, h1, h2, h3],
                hl.cast(tmp1.type(), tmp0[h0, h1, h2, h3]),
            )
            if tmp1.type().is_float()
            else hl.max(
                tmp1[h0, h1, h2, h3], hl.cast(tmp1.type(), tmp0[h0, h1, h2, h3])
            )
        )
        tmp3 = hl.Func("tmp3")
        tmp3[h0, h1, h2, h3] = in_ptr0[0, 1 + h0, h1, h2, h3]
        # Running max after candidate 2.
        tmp4 = hl.Func("tmp4")
        tmp4[h0, h1, h2, h3] = (
            hl.select(
                (tmp3[h0, h1, h2, h3] > hl.cast(tmp3.type(), tmp2[h0, h1, h2, h3]))
                | hl.is_nan(tmp3[h0, h1, h2, h3]),
                tmp3[h0, h1, h2, h3],
                hl.cast(tmp3.type(), tmp2[h0, h1, h2, h3]),
            )
            if tmp3.type().is_float()
            else hl.max(
                tmp3[h0, h1, h2, h3], hl.cast(tmp3.type(), tmp2[h0, h1, h2, h3])
            )
        )
        tmp5 = hl.Func("tmp5")
        tmp5[h0, h1, h2, h3] = in_ptr0[1, 13 + h0, h1, h2, h3]
        # Running max after candidate 3.
        tmp6 = hl.Func("tmp6")
        tmp6[h0, h1, h2, h3] = (
            hl.select(
                (tmp5[h0, h1, h2, h3] > hl.cast(tmp5.type(), tmp4[h0, h1, h2, h3]))
                | hl.is_nan(tmp5[h0, h1, h2, h3]),
                tmp5[h0, h1, h2, h3],
                hl.cast(tmp5.type(), tmp4[h0, h1, h2, h3]),
            )
            if tmp5.type().is_float()
            else hl.max(
                tmp5[h0, h1, h2, h3], hl.cast(tmp5.type(), tmp4[h0, h1, h2, h3])
            )
        )
        tmp7 = hl.Func("tmp7")
        tmp7[h0, h1, h2, h3] = in_ptr0[0, 14 + h0, h1, h2, h3]
        # Running max after candidate 4.
        tmp8 = hl.Func("tmp8")
        tmp8[h0, h1, h2, h3] = (
            hl.select(
                (tmp7[h0, h1, h2, h3] > hl.cast(tmp7.type(), tmp6[h0, h1, h2, h3]))
                | hl.is_nan(tmp7[h0, h1, h2, h3]),
                tmp7[h0, h1, h2, h3],
                hl.cast(tmp7.type(), tmp6[h0, h1, h2, h3]),
            )
            if tmp7.type().is_float()
            else hl.max(
                tmp7[h0, h1, h2, h3], hl.cast(tmp7.type(), tmp6[h0, h1, h2, h3])
            )
        )
        tmp9 = hl.Func("tmp9")
        tmp9[h0, h1, h2, h3] = in_ptr0[1, 14 + h0, h1, h2, h3]
        # Running max after candidate 5.
        tmp10 = hl.Func("tmp10")
        tmp10[h0, h1, h2, h3] = (
            hl.select(
                (tmp9[h0, h1, h2, h3] > hl.cast(tmp9.type(), tmp8[h0, h1, h2, h3]))
                | hl.is_nan(tmp9[h0, h1, h2, h3]),
                tmp9[h0, h1, h2, h3],
                hl.cast(tmp9.type(), tmp8[h0, h1, h2, h3]),
            )
            if tmp9.type().is_float()
            else hl.max(
                tmp9[h0, h1, h2, h3], hl.cast(tmp9.type(), tmp8[h0, h1, h2, h3])
            )
        )
        tmp11 = hl.Func("tmp11")
        tmp11[h0, h1, h2, h3] = in_ptr0[0, h0, 1 + h1, h2, h3]
        # Running max after candidate 6.
        tmp12 = hl.Func("tmp12")
        tmp12[h0, h1, h2, h3] = (
            hl.select(
                (tmp11[h0, h1, h2, h3] > hl.cast(tmp11.type(), tmp10[h0, h1, h2, h3]))
                | hl.is_nan(tmp11[h0, h1, h2, h3]),
                tmp11[h0, h1, h2, h3],
                hl.cast(tmp11.type(), tmp10[h0, h1, h2, h3]),
            )
            if tmp11.type().is_float()
            else hl.max(
                tmp11[h0, h1, h2, h3], hl.cast(tmp11.type(), tmp10[h0, h1, h2, h3])
            )
        )
        tmp13 = hl.Func("tmp13")
        tmp13[h0, h1, h2, h3] = in_ptr0[1, h0, 1 + h1, h2, h3]
        # Running max after candidate 7.
        tmp14 = hl.Func("tmp14")
        tmp14[h0, h1, h2, h3] = (
            hl.select(
                (tmp13[h0, h1, h2, h3] > hl.cast(tmp13.type(), tmp12[h0, h1, h2, h3]))
                | hl.is_nan(tmp13[h0, h1, h2, h3]),
                tmp13[h0, h1, h2, h3],
                hl.cast(tmp13.type(), tmp12[h0, h1, h2, h3]),
            )
            if tmp13.type().is_float()
            else hl.max(
                tmp13[h0, h1, h2, h3], hl.cast(tmp13.type(), tmp12[h0, h1, h2, h3])
            )
        )
        tmp15 = hl.Func("tmp15")
        tmp15[h0, h1, h2, h3] = in_ptr0[0, 1 + h0, 1 + h1, h2, h3]
        # Running max after candidate 8 (the last one).
        tmp16 = hl.Func("tmp16")
        tmp16[h0, h1, h2, h3] = (
            hl.select(
                (tmp15[h0, h1, h2, h3] > hl.cast(tmp15.type(), tmp14[h0, h1, h2, h3]))
                | hl.is_nan(tmp15[h0, h1, h2, h3]),
                tmp15[h0, h1, h2, h3],
                hl.cast(tmp15.type(), tmp14[h0, h1, h2, h3]),
            )
            if tmp15.type().is_float()
            else hl.max(
                tmp15[h0, h1, h2, h3], hl.cast(tmp15.type(), tmp14[h0, h1, h2, h3])
            )
        )
        # First output: the maximum over all nine candidates.
        out_ptr0[h0, h1, h2, h3] = hl.cast(hl.Float(32), tmp16[h0, h1, h2, h3])

        # --- Index chain -------------------------------------------------
        # For each candidate k (1..8) compare it with the running max that
        # preceded it; when it is strictly greater, the int8 winner index
        # becomes k, otherwise the previous winner index is kept.  The index
        # starts at 0 (candidate 0).
        # NOTE(review): these comparisons use a plain `>` without the
        # is_nan() clause used in the value chain above, so a NaN candidate
        # replaces the value but does not update the index -- confirm that
        # this is intended.
        tmp17 = hl.Func("tmp17")
        tmp17[h0, h1, h2, h3] = tmp1[h0, h1, h2, h3] > tmp0[h0, h1, h2, h3]
        # Zero-dimensional Funcs (indexed with `()`) hold scalar constants.
        tmp18 = hl.Func("tmp18")
        tmp18[()] = hl.cast(hl.Int(8), 1)
        tmp19 = hl.Func("tmp19")
        tmp19[()] = hl.cast(hl.Int(8), 0)
        tmp20 = hl.Func("tmp20")
        tmp20[h0, h1, h2, h3] = hl.select(
            tmp17[h0, h1, h2, h3], tmp18[()], hl.cast(tmp18.type(), tmp19[()])
        )
        tmp21 = hl.Func("tmp21")
        tmp21[h0, h1, h2, h3] = tmp3[h0, h1, h2, h3] > tmp2[h0, h1, h2, h3]
        tmp22 = hl.Func("tmp22")
        tmp22[()] = hl.cast(hl.Int(8), 2)
        tmp23 = hl.Func("tmp23")
        tmp23[h0, h1, h2, h3] = hl.select(
            tmp21[h0, h1, h2, h3],
            tmp22[()],
            hl.cast(tmp22.type(), tmp20[h0, h1, h2, h3]),
        )
        tmp24 = hl.Func("tmp24")
        tmp24[h0, h1, h2, h3] = tmp5[h0, h1, h2, h3] > tmp4[h0, h1, h2, h3]
        tmp25 = hl.Func("tmp25")
        tmp25[()] = hl.cast(hl.Int(8), 3)
        tmp26 = hl.Func("tmp26")
        tmp26[h0, h1, h2, h3] = hl.select(
            tmp24[h0, h1, h2, h3],
            tmp25[()],
            hl.cast(tmp25.type(), tmp23[h0, h1, h2, h3]),
        )
        tmp27 = hl.Func("tmp27")
        tmp27[h0, h1, h2, h3] = tmp7[h0, h1, h2, h3] > tmp6[h0, h1, h2, h3]
        tmp28 = hl.Func("tmp28")
        tmp28[()] = hl.cast(hl.Int(8), 4)
        tmp29 = hl.Func("tmp29")
        tmp29[h0, h1, h2, h3] = hl.select(
            tmp27[h0, h1, h2, h3],
            tmp28[()],
            hl.cast(tmp28.type(), tmp26[h0, h1, h2, h3]),
        )
        tmp30 = hl.Func("tmp30")
        tmp30[h0, h1, h2, h3] = tmp9[h0, h1, h2, h3] > tmp8[h0, h1, h2, h3]
        tmp31 = hl.Func("tmp31")
        tmp31[()] = hl.cast(hl.Int(8), 5)
        tmp32 = hl.Func("tmp32")
        tmp32[h0, h1, h2, h3] = hl.select(
            tmp30[h0, h1, h2, h3],
            tmp31[()],
            hl.cast(tmp31.type(), tmp29[h0, h1, h2, h3]),
        )
        tmp33 = hl.Func("tmp33")
        tmp33[h0, h1, h2, h3] = tmp11[h0, h1, h2, h3] > tmp10[h0, h1, h2, h3]
        tmp34 = hl.Func("tmp34")
        tmp34[()] = hl.cast(hl.Int(8), 6)
        tmp35 = hl.Func("tmp35")
        tmp35[h0, h1, h2, h3] = hl.select(
            tmp33[h0, h1, h2, h3],
            tmp34[()],
            hl.cast(tmp34.type(), tmp32[h0, h1, h2, h3]),
        )
        tmp36 = hl.Func("tmp36")
        tmp36[h0, h1, h2, h3] = tmp13[h0, h1, h2, h3] > tmp12[h0, h1, h2, h3]
        tmp37 = hl.Func("tmp37")
        tmp37[()] = hl.cast(hl.Int(8), 7)
        tmp38 = hl.Func("tmp38")
        tmp38[h0, h1, h2, h3] = hl.select(
            tmp36[h0, h1, h2, h3],
            tmp37[()],
            hl.cast(tmp37.type(), tmp35[h0, h1, h2, h3]),
        )
        tmp39 = hl.Func("tmp39")
        tmp39[h0, h1, h2, h3] = tmp15[h0, h1, h2, h3] > tmp14[h0, h1, h2, h3]
        tmp40 = hl.Func("tmp40")
        tmp40[()] = hl.cast(hl.Int(8), 8)
        # tmp41 is the final winner index, an int8 in the range 0..8.
        tmp41 = hl.Func("tmp41")
        tmp41[h0, h1, h2, h3] = hl.select(
            tmp39[h0, h1, h2, h3],
            tmp40[()],
            hl.cast(tmp40.type(), tmp38[h0, h1, h2, h3]),
        )

        # --- Convert the winner index into the output index --------------
        # Split the winner index by 3: tmp43 = floor(idx / 3) and
        # tmp45 = idx - 3 * tmp43 (the remainder).  The division is done in
        # floating point, at least 32 bits wide.
        tmp42 = hl.Func("tmp42")
        tmp42[()] = hl.cast(hl.Int(32), 3)
        tmp43 = hl.Func("tmp43")
        tmp43[h0, h1, h2, h3] = hl.floor(
            hl.cast(hl.Float(max(32, tmp41.type().bits())), tmp41[h0, h1, h2, h3])
            / tmp42[()]
        )
        tmp44 = hl.Func("tmp44")
        tmp44[h0, h1, h2, h3] = tmp43[h0, h1, h2, h3] * tmp42[()]
        tmp45 = hl.Func("tmp45")
        tmp45[h0, h1, h2, h3] = tmp41[h0, h1, h2, h3] - tmp44[h0, h1, h2, h3]
        # Combined result: (2 * h1 + tmp43) * 27 + (2 * h0 + tmp45).
        tmp46 = hl.Func("tmp46")
        tmp46[h1] = 2 * h1
        tmp47 = hl.Func("tmp47")
        tmp47[h0, h1, h2, h3] = tmp46[h1] + tmp43[h0, h1, h2, h3]
        tmp48 = hl.Func("tmp48")
        tmp48[h0] = 2 * h0
        tmp49 = hl.Func("tmp49")
        tmp49[h0, h1, h2, h3] = tmp48[h0] + tmp45[h0, h1, h2, h3]
        tmp50 = hl.Func("tmp50")
        tmp50[()] = hl.cast(hl.Int(64), 27)
        tmp51 = hl.Func("tmp51")
        tmp51[h0, h1, h2, h3] = tmp47[h0, h1, h2, h3] * tmp50[()]
        tmp52 = hl.Func("tmp52")
        tmp52[h0, h1, h2, h3] = tmp51[h0, h1, h2, h3] + tmp49[h0, h1, h2, h3]
        # Second output: the combined index, cast to int64.
        out_ptr2[h0, h1, h2, h3] = hl.cast(hl.Int(64), tmp52[h0, h1, h2, h3])

        # --- Constraints and estimates for the autoscheduler -------------
        # No manual schedule is provided, so an autoscheduler is required.
        assert g.using_autoscheduler()
        # Pin the input layout: every dimension starts at 0 and has an
        # explicit stride and extent.
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(2)
        in_ptr0.dim(1).set_min(0)
        in_ptr0.dim(1).set_stride(2)
        in_ptr0.dim(1).set_extent(13)
        in_ptr0.dim(2).set_min(0)
        in_ptr0.dim(2).set_stride(54)
        in_ptr0.dim(2).set_extent(13)
        in_ptr0.dim(3).set_min(0)
        in_ptr0.dim(3).set_stride(729)
        in_ptr0.dim(3).set_extent(192)
        in_ptr0.dim(4).set_min(0)
        in_ptr0.dim(4).set_stride(139968)
        in_ptr0.dim(4).set_extent(128)
        # Size estimates for the input; they repeat the extents set above.
        in_ptr0.set_estimates(
            [
                hl.Range(0, 2),
                hl.Range(0, 13),
                hl.Range(0, 13),
                hl.Range(0, 192),
                hl.Range(0, 128),
            ]
        )
        # Both outputs share the same estimated 13 x 13 x 192 x 128 region.
        out_ptr0.set_estimates(
            [hl.Range(0, 13), hl.Range(0, 13), hl.Range(0, 192), hl.Range(0, 128)]
        )
        out_ptr2.set_estimates(
            [hl.Range(0, 13), hl.Range(0, 13), hl.Range(0, 192), hl.Range(0, 128)]
        )


if __name__ == "__main__":
    import sys
    import tempfile

    # Shared library that provides the Anderson2021 autoscheduler plugin.
    autoscheduler_plugin = "/home/jansel/conda/envs/pytorch/lib/python3.12/site-packages/halide/lib64/libautoschedule_anderson2021.so"
    # Trailing key=value generator parameters: build target and autoscheduler.
    generator_params = (
        "target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts",
        "autoscheduler=Anderson2021",
        "autoscheduler.parallelism=82",
    )

    # Emit all artifacts into a throwaway directory that is removed on exit.
    with tempfile.TemporaryDirectory() as scratch_dir:
        # hl.main() reads its arguments from sys.argv, so fake a command line.
        cli = ["repro.py"]
        cli += ["-g", "kernel"]
        cli += ["-o", scratch_dir]
        cli += ["-f", "halide_kernel"]
        cli += ["-e", "static_library,h,schedule"]
        cli += ["-p", autoscheduler_plugin]
        cli += generator_params
        sys.argv = cli
        hl.main()

cc @alexreinking — this example comes from:

python benchmarks/dynamo/microbenchmarks/operatorbench.py --inductor-config autotune --inductor-config halide --op aten.max_pool2d_with_indices.default --max-samples 1 --start-idx 4

on pytorch/pytorch#136809

@abadams
Copy link
Member

abadams commented Sep 28, 2024

So it takes 32 minutes... but still successfully compiles? Interesting. Maybe there's a lurking pass with exponential complexity for this example.

@jansel
Copy link
Contributor Author

jansel commented Sep 28, 2024

Yeah, it finishes and runs correctly.

@abadams abadams added the autoscheduler Related to one or more of the Autoschedulers label Sep 30, 2024
@abadams
Copy link
Member

abadams commented Sep 30, 2024

Looks like it's not compilation proper, but rather the Anderson2021 autoscheduler getting stuck enumerating a combinatorial number of tiling options, which is a bit absurd given that this entire pipeline seems to be elementwise other than accesses to the input buffer.

A workaround would be to ask the autoscheduler to do a lot less by generating an Expr instead of a Func for anything that has no update definition and is either consumed elementwise or is an op that is cheaper than a load (e.g. tmp48).

@alexreinking alexreinking self-assigned this Sep 30, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
autoscheduler Related to one or more of the Autoschedulers
Projects
None yet
Development

No branches or pull requests

3 participants