Skip to content

Commit

Permalink
fix: Support globs in build strings
Browse files Browse the repository at this point in the history
Signed-off-by: Julien Jerphanion <[email protected]>
  • Loading branch information
jjerphan committed Jan 8, 2025
1 parent 979162f commit 844e4bc
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 15 deletions.
2 changes: 1 addition & 1 deletion libmamba/include/mamba/specs/regex_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace mamba::specs
[[nodiscard]] static auto parse(std::string pattern) -> expected_parse_t<RegexSpec>;

RegexSpec();
RegexSpec(std::regex pattern, std::string raw_pattern);
RegexSpec(std::string raw_pattern);

[[nodiscard]] auto contains(std::string_view str) const -> bool;

Expand Down
51 changes: 37 additions & 14 deletions libmamba/src/specs/regex_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>

#include <fmt/format.h>

Expand All @@ -25,12 +27,10 @@ namespace mamba::specs

auto RegexSpec::parse(std::string pattern) -> expected_parse_t<RegexSpec>
{
// No other mean of getting parse result with ``std::regex``, but parse error need
// to be handled by ``tl::expected`` to be managed down the road.
// Parse error need to be handled by ``tl::expected`` to be managed down the road.
try
{
auto regex = std::regex(pattern);
return { { std::move(regex), std::move(pattern) } };
return { std::move(pattern) };
}
catch (const std::regex_error& e)
{
Expand All @@ -39,25 +39,48 @@ namespace mamba::specs
}

RegexSpec::RegexSpec()
: RegexSpec(std::regex(free_pattern.data(), free_pattern.size()), std::string(free_pattern))
: RegexSpec(std::string(free_pattern))
{
}

RegexSpec::RegexSpec(std::regex pattern, std::string raw_pattern)
: m_pattern(std::move(pattern))
, m_raw_pattern(std::move(raw_pattern))
RegexSpec::RegexSpec(std::string raw_pattern)
{
// Construct ss from raw_pattern, in particular make sure to replace all `*` by `.*`
// in the pattern if they are not preceded by a `.`.
// We force regex to start with `^` and end with `$` to simplify the multiple
// possible representations, and because this is the safest way we can make sure it is
// not a glob when serializing it.
if (!util::starts_with(m_raw_pattern, pattern_start))
{
m_raw_pattern.insert(m_raw_pattern.begin(), pattern_start);
}
if (!util::ends_with(m_raw_pattern, pattern_end))
std::ostringstream ss;
ss << pattern_start;

auto first_character_it = raw_pattern.cbegin();
auto last_character_it = raw_pattern.cend() - 1;

for (auto it = first_character_it; it != raw_pattern.cend(); ++it)
{
m_raw_pattern.push_back(pattern_end);
if (it == first_character_it && *it == pattern_start)
{
continue;
}
if (it == last_character_it && *it == pattern_end)
{
continue;
}
if (*it == '*' && (it == first_character_it || *(it - 1) != '.'))
{
ss << ".*";
}
else
{
ss << *it;
}
}

ss << pattern_end;

m_raw_pattern = ss.str();

m_pattern = std::regex(m_raw_pattern);
}

auto RegexSpec::contains(std::string_view str) const -> bool
Expand Down
19 changes: 19 additions & 0 deletions libmamba/tests/src/specs/test_match_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,25 @@ namespace
/* .track_features =*/{ "openssl", "mkl" },
}));
}

SECTION("pytorch=2.3.1=py3.10_cuda11.8*")
{
// Check that it contains `pytorch=2.3.1=py3.10_cuda11.8_cudnn8.7.0_0`

const auto ms = "pytorch=2.3.1=py3.10_cuda11.8*"_ms;

REQUIRE(ms.contains_except_channel(Pkg{
/* .name= */ "pytorch",
/* .version= */ "2.3.1"_v,
/* .build_string= */ "py3.10_cuda11.8_cudnn8.7.0_0",
/* .build_number= */ 0,
/* .md5= */ "lemd5",
/* .sha256= */ "somesha256",
/* .license= */ "GPL",
/* .platform= */ "linux-64",
/* .track_features =*/{},
}));
}
}

TEST_CASE("MatchSpec comparability and hashability")
Expand Down
6 changes: 6 additions & 0 deletions libmamba/tests/src/specs/test_regex_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,10 @@ namespace
REQUIRE(hash_fn(spec1) == hash_fn(spec2));
REQUIRE(hash_fn(spec1) != hash_fn(spec3));
}

TEST_CASE("RegexSpec ^py3.10_cuda11.8*$")
{
auto spec = RegexSpec::parse("^py3.10_cuda11.8*$").value();
REQUIRE(spec.contains("py3.10_cuda11.8_cudnn8.7.0_0"));
}
}
40 changes: 40 additions & 0 deletions micromamba/tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,3 +1575,43 @@ def test_update_spec_list(tmp_path):
out = helpers.create("-p", env_prefix, "-f", env_spec_file, "--dry-run")

assert update_specs_list in out.replace("\r", "")


def test_glob_in_build_string(tmp_path):
# Non-regression test for https://github.com/mamba-org/mamba/issues/3699
env_prefix = tmp_path / "test_glob_in_build_string"

pytorch_match_spec = "pytorch=2.3.1=py3.10_cuda11.8*"

# Export CONDA_OVERRIDE_GLIBC=2.17 to force the solver to use the glibc 2.17 package
try:
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.17"

# Should run without error
out = helpers.create(
"-p",
env_prefix,
pytorch_match_spec,
"-c",
"pytorch",
"-c",
"nvidia/label/cuda-11.8.0",
"-c",
"nvidia",
"-c",
"conda-forge",
"--platform",
"linux-64",
"--dry-run",
"--json",
)
finally:
os.environ.pop("CONDA_OVERRIDE_GLIBC", None)

# Check that a build of pytorch 2.3.1 with `py3.10_cuda11.8_cudnn8.7.0_0` as a build string is found
assert any(
package["name"] == "pytorch"
and package["version"] == "2.3.1"
and package["build_string"] == "py3.10_cuda11.8_cudnn8.7.0_0"
for package in out["actions"]["FETCH"]
)

0 comments on commit 844e4bc

Please sign in to comment.