Skip to content

Commit

Permalink
Merge pull request #28 from RelationalAI/adding-unsafe
Browse files Browse the repository at this point in the history
Adding unsafe
  • Loading branch information
bergel authored Mar 8, 2024
2 parents 2607c86 + d926f59 commit 851d10b
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 18 deletions.
27 changes: 23 additions & 4 deletions src/linting/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,28 @@ LintOptions(::Colon) = LintOptions(fill(true, length(default_options))...)
LintOptions(options::Vararg{Union{Bool,Nothing},length(default_options)}) =
LintOptions(something.(options, default_options)...)


function fetch_value(x::EXPR, tag::Symbol)
if headof(x) == tag
return x.val
else
isnothing(x.args) && return nothing
for i in 1:length(x.args)
r = fetch_value(x.args[i], tag)
isnothing(r) || return r
end
return nothing
end
end

function check_all(x::EXPR, opts::LintOptions, env::ExternalEnv, markers::Dict{Symbol,String}=Dict{Symbol,String}())
# Setting up the markers
if headof(x) === :const
markers[:const] = "const"
markers[:const] = fetch_value(x, :IDENTIFIER)
end

if headof(x) === :function
markers[:function] = fetch_value(x, :IDENTIFIER)
end

# Do checks
Expand Down Expand Up @@ -145,9 +164,9 @@ function check_all(x::EXPR, opts::LintOptions, env::ExternalEnv, markers::Dict{S
end
end

if headof(x) === :const
delete!(markers, :const)
end
# Do some cleaning
headof(x) === :const && delete!(markers, :const)
headof(x) === :function && delete!(markers, :function)
end

function _typeof(x, state)
Expand Down
42 changes: 38 additions & 4 deletions src/linting/extended_checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,30 @@ function is_hole_variable_star(x::CSTParser.EXPR)
end

comp(x, y) = x == y

struct BothCannotHaveStarException <: Exception
msg::String
end

comp_value(x, y) = x == y
function comp_value(x::String, y::String)
is_there_any_star_marker = contains(x, "QQQ") || contains(y, "QQQ")
!is_there_any_star_marker && return x == y

contains(x, "QQQ") && contains(y, "QQQ") && throw(BothCannotHaveStarException("Cannot both $x and $y have a star marker"))
if contains(x, "QQQ")
reg_exp = Regex(replace(x, "QQQ" => ".*"))
return !isnothing(match(reg_exp, y))
else
reg_exp = Regex(replace(y, "QQQ" => ".*"))
return !isnothing(match(reg_exp, x))
end
end

function comp(x::CSTParser.EXPR, y::CSTParser.EXPR)
(is_hole_variable(x) || is_hole_variable(y)) && return true

result = comp(x.head, y.head) && x.val == y.val
result = comp(x.head, y.head) && comp_value(x.val, y.val)
!result && return false

min_length = min(length(x), length(y))
Expand Down Expand Up @@ -71,11 +91,11 @@ struct Channel_Extension <: ExtendedRule end
struct Task_Extension <: ExtendedRule end
struct ErrorException_Extension <: ExtendedRule end
struct Error_Extension <: ExtendedRule end
struct Unsafe_Extension <: ExtendedRule end
struct In_Extension <: ExtendedRule end
struct HasKey_Extension <: ExtendedRule end
struct Equal_Extension <: ExtendedRule end


const all_extended_rule_types = Ref{Any}(InteractiveUtils.subtypes(ExtendedRule))

# template -> EXPR to be compared
Expand Down Expand Up @@ -221,11 +241,25 @@ function check(::Error_Extension, x::EXPR)
"Use custom exception instead of the generic `error(...)`")
end

function check(::Unsafe_Extension, x::EXPR, markers::Dict{Symbol,String})
haskey(markers, :function) || return
isnothing(match(r"_unsafe_.*", markers[:function])) || return
isnothing(match(r"unsafe_.*", markers[:function])) || return

generic_check(
x,
"unsafe_QQQ(hole_variable_star)",
"An `unsafe_` function should be called only from an `unsafe_` function.")
generic_check(
x,
"_unsafe_QQQ(hole_variable_star)",
"An `unsafe_` function should be called only from an `unsafe_` function.")
end

function check(::In_Extension, x::EXPR)
msg = "It is preferable to use `tin(item,collection)` instead of the Julia's `in`."
generic_check(x, "in(hole_variable,hole_variable)", msg)
generic_check(x, "hole_variable in hole_variable", msg)

end

function check(::HasKey_Extension, x::EXPR)
Expand All @@ -236,4 +270,4 @@ end
function check(::Equal_Extension, x::EXPR)
msg = "It is preferable to use `tequal(dict,key)` instead of the Julia's `equal`."
generic_check(x, "equal(hole_variable,hole_variable)", msg)
end
end
112 changes: 102 additions & 10 deletions test/rai_rules_tests.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
using StaticLint: StaticLint, run_lint_on_text, comp, convert_offset_to_line,
convert_offset_to_line_from_lines, should_be_filtered, MarkdownFormat, PlainFormat
convert_offset_to_line_from_lines, should_be_filtered, MarkdownFormat, PlainFormat,
fetch_value
import CSTParser
using Test
using JSON3

# Reset the caches before running the tests.
StaticLint.reset_static_lint_caches()

function foo()
@async 1 + 2
end

const n = Threads.nthreads()

function lint_test(source::String, expected_substring::String; verbose=true, directory::String = "")
io = IOBuffer()
run_lint_on_text(source; io, directory)
Expand Down Expand Up @@ -361,7 +356,7 @@ end
"Line 12, column 10: `mmap` should be used with extreme caution.")
@test lint_test(source,
"Line 13, column 10: `mmap` should be used with extreme caution.")
@test lint_test(source,
@test lint_test(source,
"Line 14, column 12: `Future` should be used with extreme caution.")
@test lint_test(source,
"Line 15, column 12: `Future` should be used with extreme caution.")
Expand Down Expand Up @@ -486,7 +481,7 @@ end
end
end

@testset "Comparison" begin
@testset "Comparing AST to templates" begin
t(s1, s2) = comp(CSTParser.parse(s1), CSTParser.parse(s2))
@test t("Threads.nthreads()", "Threads.nthreads()")
@test !t("QWEThreads.nthreads()", "Threads.nthreads()")
Expand Down Expand Up @@ -536,11 +531,89 @@ end
@test !t("Future()", "Future{hole_variable}(hole_variable_star)")
@test !t("Future{Any}() do f nothing end", "Future{hole_variable}(hole_variable_star)")

# Partial matching
@test t("foo", "foo")
@test t("fooQQQ", "foo")
@test t("QQQfoo", "foo")
@test t("QQQfooQQQ", "foo")

@test !t("foo", "foobar")
@test t("fooQQQ", "foobar")
@test t("QQQfoo", "barfoo")
@test t("QQQfooQQQ", "barfoozork")

@test !t("foo", "foo_bar")
@test t("fooQQQ", "foo_bar")
@test t("QQQfoo", "bar_foo")
@test t("QQQfooQQQ", "bar_foo_zork")

@test t("foo(x, QQQzork)", "foo(x, zork)")
@test t("foo(x, QQQzork)", "foo(x, blah_zork)")

# in keyword
@test t("in(hole_variable,hole_variable)", "in(x,y)")
@test t("x in y", "hole_variable in hole_variable")
end

@testset "unsafe functions" begin
@testset "No error 01" begin
source = """
function unsafe_f()
unsafe_g()
end
function unsafe_g()
return 42
end
"""
@test !lint_has_error_test(source)
end

@testset "Some errors 01" begin
source = """
function f()
unsafe_g()
end
function unsafe_g()
return 42
end
"""
@test lint_has_error_test(source)
@test lint_test(source,
"Line 2, column 5: An `unsafe_` function should be called only from an `unsafe_` function.")
end

@testset "Some errors 02" begin
source = """
function f()
_unsafe_g()
end
function _unsafe_g()
return 42
end
"""
@test lint_has_error_test(source)
@test lint_test(source,
"Line 2, column 5: An `unsafe_` function should be called only from an `unsafe_` function.")
end

@testset "No error 02" begin
source = """
function f()
# lint-disable-next-line
_unsafe_g()
end
function _unsafe_g()
return 42
end
"""
@test !lint_has_error_test(source)
end
end

@testset "offset to line" begin
source = """
function f()
Expand Down Expand Up @@ -568,6 +641,25 @@ end
@test !should_be_filtered(hint_as_string2, filters)
end

@testset "Fetching values from AST" begin
@test fetch_value(CSTParser.parse("f"), :IDENTIFIER) == "f"
@test fetch_value(CSTParser.parse("f()"), :IDENTIFIER) == "f"
@test fetch_value(CSTParser.parse("begin f(g()) end"), :IDENTIFIER) == "f"

source = """
struct _SyncDict{Dict}
lock::Base.Threads.SpinLock
dict::Dict
function _SyncDict{Dict}() where {Dict}
new{Dict}(Base.Threads.SpinLock(), Dict())
end
end
"""
@test fetch_value(CSTParser.parse(source), :IDENTIFIER) == "_SyncDict"

end

@testset "Formatter" begin
source = """
const x = Threads.nthreads()
Expand Down Expand Up @@ -1069,4 +1161,4 @@ end
return 42
end
""")
end
end

0 comments on commit 851d10b

Please sign in to comment.