
[Relax][Bug] Segmentation fault when using the MergeCompositeFunctions transform #17120

Closed
Cookiee235 opened this issue Jun 27, 2024 · 2 comments · Fixed by #17220

Labels: needs-triage, type: bug

Comments

@Cookiee235 (Contributor)

Cookiee235 commented Jun 27, 2024

Actual behavior

Segmentation fault (core dumped)

Environment

TVM: 0.17.dev0
OS: Ubuntu 20.04

Steps to reproduce

```python
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    # Plain TIR ReLU kernel, called from `main` through R.call_tir.
    @T.prim_func(private=True)
    def relu(x11: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0 in range(T.int64(10)):
            with T.block("compute"):
                v_i0 = T.axis.spatial(T.int64(10), i0)
                T.reads(x11[v_i0])
                T.writes(compute[v_i0])
                compute[v_i0] = T.max(x11[v_i0], T.float32(0))

    # Composite function tagged for the "compiler_A" codegen (GELU).
    @R.function(private=True)
    def fused_relax_nn_gelu(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
        cls = Module
        with R.dataflow():
            gv3 = R.nn.gelu(x21)
            R.output(gv3)
        return gv3

    # Composite function tagged for the "compiler_A" codegen (ReLU).
    @R.function(private=True)
    def fused_relax_nn_relu(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
        cls = Module
        with R.dataflow():
            # gv2 = R.call_tir(cls.relu, (x11,), out_sinfo=R.Tensor((10,), dtype="float32"))
            gv2 = R.nn.relu(x11)
            R.output(gv2)
        return gv2

    @R.function
    def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv1: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu(x1)
            # An R.call_tir binding sits between the two composite calls.
            lv2 = R.call_tir(cls.relu, (lv1,), out_sinfo=R.Tensor((10,), dtype="float32"))
            lv3: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_gelu(lv2)
            R.output(lv3)
        return lv3

mod = Module
mod.show()
mod = relax.transform.MergeCompositeFunctions()(mod)  # segmentation fault here
```
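Because the failure is a native crash rather than a Python exception, it takes the whole interpreter down with it. A minimal sketch for checking the reproducer from a script or test runner, assuming a POSIX system: run the code in a child interpreter and inspect its exit status. `classify_run` is a hypothetical helper, not part of TVM; it relies on `subprocess` reporting a child killed by signal N as a return code of -N.

```python
import signal
import subprocess
import sys


def classify_run(source: str, timeout: float = 300.0) -> str:
    """Run `source` in a fresh Python interpreter and classify how it ended.

    Hypothetical helper, POSIX only. A native crash kills only the child
    process, so the caller survives to inspect the exit status.
    """
    result = subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    # On POSIX, a child terminated by signal N has returncode == -N.
    if result.returncode == -signal.SIGSEGV:
        return "segfault"
    # Any other non-zero status, e.g. 1 for an uncaught Python exception.
    if result.returncode != 0:
        return "error"
    return "ok"
```

For example, passing the reproducer source as a string should give `"segfault"` on an affected build and `"ok"` once the fix is present; a build that raises a Python exception instead shows up as `"error"`.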

Triage

  • needs-triage

cc @junrushao

Cookiee235 added the needs-triage and type: bug labels Jun 27, 2024
Cookiee235 changed the title from "[Relax][Bug] Segmentation fault When execute the MergeCompositeFunctions" to "[Relax][Bug] Segmentation fault when use the MergeCompositeFunctions transform" Jun 27, 2024
Cookiee235 changed the title from "[Relax][Bug] Segmentation fault when use the MergeCompositeFunctions transform" to "[Relax][Bug] Segmentation fault when using the MergeCompositeFunctions transform" Jun 27, 2024
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Jul 30, 2024
Prior to this commit, use of `R.call_tir` in the input to
`MergeCompositeFunctions` would result in a segfault when the pass
attempted to determine the `Group*` that contains the
`relax::GlobalVar` of the callee.

This commit updates `MergeCompositeFunctions` to check for
`relax::GlobalVar` and `relax::Tuple` instances.

Closes apache#17120
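The failure mode described in this commit message can be modelled in a few lines of plain Python. This is a conceptual sketch, not TVM code: `Var`, `GlobalVar`, `Tuple`, `Call`, and the `memo` dict are hypothetical stand-ins for the Relax IR nodes and the pass's expression-to-group map. The guard (skip a callee `GlobalVar`, look through a `Tuple` to its fields) is an assumption about what "check for `relax::GlobalVar` and `relax::Tuple` instances" means; the actual C++ change in #17220 may differ in detail.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple as PyTuple, Union


@dataclass(frozen=True)
class Var:
    """A dataflow value, e.g. `lv1` in `main`."""
    name: str


@dataclass(frozen=True)
class GlobalVar:
    """A reference to a module-level function, e.g. the `relu` PrimFunc."""
    name: str


@dataclass(frozen=True)
class Tuple:
    """A tuple expression, e.g. the `(lv1,)` argument list of R.call_tir."""
    fields: PyTuple["Expr", ...]


@dataclass(frozen=True)
class Call:
    op: str
    args: PyTuple["Expr", ...]


Expr = Union[Var, GlobalVar, Tuple, Call]


def arg_groups_unchecked(call: Call, memo: Dict[Expr, str]) -> List[str]:
    # Modelled pre-fix behaviour: assume every argument already has a group.
    # For R.call_tir the GlobalVar callee has none, so this lookup fails
    # (a KeyError here; a crash in the C++ pass).
    return [memo[arg] for arg in call.args]


def arg_groups_checked(call: Call, memo: Dict[Expr, str]) -> List[str]:
    # Modelled fix: special-case GlobalVar and Tuple before looking up groups.
    groups: List[str] = []

    def visit(expr: Expr) -> None:
        if isinstance(expr, GlobalVar):
            return  # a callee reference is not a value produced by any group
        if isinstance(expr, Tuple):
            for field in expr.fields:
                visit(field)  # look up each tuple element individually
            return
        groups.append(memo[expr])

    for arg in call.args:
        visit(arg)
    return groups


# Models `lv2 = R.call_tir(cls.relu, (lv1,), ...)` from the reproducer.
call_tir = Call("relax.call_tir", (GlobalVar("relu"), Tuple((Var("lv1"),))))
memo = {Var("lv1"): "compiler_A group"}
print(arg_groups_checked(call_tir, memo))  # ['compiler_A group']
```

In this model the unchecked version fails on the very first argument of the `R.call_tir` call, while the checked version ignores the callee and returns only the group of `lv1`.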
@Lunderberg (Contributor)

I'd been hoping that this one would be resolved incidentally through #17212, and while it did change the segfault to an exception, it didn't solve the root cause. This bug should now be fixed with #17220.

@Cookiee235 (Contributor, Author)

@Lunderberg Thanks a lot! The test case above now runs correctly!
