Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Cydral authored Nov 1, 2024
1 parent b9b1c88 commit 7649788
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions dlib/dnn/layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4750,6 +4750,7 @@ namespace dlib
void forward(const SUBNET& sub, resizable_tensor& output)
{
const auto& prev_output = sub.get_output();
if (!have_same_dimensions(prev_output, pe)) setup(sub);
output.set_size(prev_output.num_samples(), prev_output.k(), sequence_dim, embedding_dim);
tt::add(output, prev_output, pe);
}
Expand All @@ -4767,22 +4768,16 @@ namespace dlib
const tensor& get_positional_encodings() const { return pe; }
tensor& get_positional_encodings() { return pe; }

friend void serialize(const positional_encodings_& item, std::ostream& out)
friend void serialize(const positional_encodings_& /*item*/, std::ostream& out)
{
serialize("positional_encodings_", out);
serialize(item.pe, out);
serialize(item.sequence_dim, out);
serialize(item.embedding_dim, out);
}
friend void deserialize(positional_encodings_& item, std::istream& in)
friend void deserialize(positional_encodings_& /*item*/, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "positional_encodings_")
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::positional_encodings_.");
deserialize(item.pe, in);
deserialize(item.sequence_dim, in);
deserialize(item.embedding_dim, in);
}

friend std::ostream& operator<<(std::ostream& out, const positional_encodings_& /*item*/)
Expand Down

0 comments on commit 7649788

Please sign in to comment.