From 76497880dc7376cf2114ead06c8cbd00dbd0914e Mon Sep 17 00:00:00 2001 From: Cydral <53169060+Cydral@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:58:45 +0100 Subject: [PATCH] Update --- dlib/dnn/layers.h | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index f7454e6adc..a73ee7fb69 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -4750,6 +4750,7 @@ namespace dlib void forward(const SUBNET& sub, resizable_tensor& output) { const auto& prev_output = sub.get_output(); + if (!have_same_dimensions(prev_output, pe)) setup(sub); output.set_size(prev_output.num_samples(), prev_output.k(), sequence_dim, embedding_dim); tt::add(output, prev_output, pe); } @@ -4767,22 +4768,16 @@ namespace dlib const tensor& get_positional_encodings() const { return pe; } tensor& get_positional_encodings() { return pe; } - friend void serialize(const positional_encodings_& item, std::ostream& out) + friend void serialize(const positional_encodings_& /*item*/, std::ostream& out) { serialize("positional_encodings_", out); - serialize(item.pe, out); - serialize(item.sequence_dim, out); - serialize(item.embedding_dim, out); } - friend void deserialize(positional_encodings_& item, std::istream& in) + friend void deserialize(positional_encodings_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "positional_encodings_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::positional_encodings_."); - deserialize(item.pe, in); - deserialize(item.sequence_dim, in); - deserialize(item.embedding_dim, in); } friend std::ostream& operator<<(std::ostream& out, const positional_encodings_& /*item*/)