Optional apply logprob computation at call site instead of construction site #71

Open
jerinphilip opened this issue Feb 22, 2022 · 0 comments
Labels
enhancement New feature or request

This issue tracks a possible workaround for making QE optional at run time, rather than fixed at construction time (skip-cost=true).

Trace:

skip-cost:

bool skipCost = options->get<bool>("skip-cost");
auto encdec = models::createModelFromOptions(
    options, skipCost ? models::usage::raw : models::usage::translation);

createModelFromOptions:

// add (log)softmax if requested
if (use == usage::translation) {
  if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
    if(options->get<bool>("output-sampling", false))
      return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
    else
      return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());

Stepwise:

// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,
// wrapped again in the IEncoderDecoder interface
// @TODO: seems we are conflating an interface definition with its implementation?
// @TODO: needs a better name. Stepwise is an adjective. Classes are things=nouns. StepwiseWhat?
class Stepwise : public IEncoderDecoder {

Stepwise, relevant call site:

virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
                               Ptr<DecoderState> state,
                               const std::vector<IndexType>& hypIndices,   // [beamIndex * activeBatchSize + batchIndex]
                               const Words& words,                         // [beamIndex * activeBatchSize + batchIndex]
                               const std::vector<IndexType>& batchIndices, // [batchIndex]
                               int beamSize) override {
  auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
  return cost_->apply(nextState);
}

One possibility: add a bool skipCost argument here, defaulting to false, skip the cost operation when skipCost=true, and set the argument from beam search (see the call site below). Would that work?
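
A minimal sketch of that idea, using stand-in types rather than Marian's real classes. The trimmed step signature (no graph, indices, or words), DecoderState::appliedOps, FakeEncoderDecoder, and FakeLogSoftmaxStep are all hypothetical scaffolding invented for illustration; only the shape of Stepwise::step mirrors the excerpt above.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

template <class T> using Ptr = std::shared_ptr<T>;

// Hypothetical stand-in for the decoder state: records which stages ran.
struct DecoderState {
  std::vector<std::string> appliedOps;
};

struct IEncoderDecoder {
  virtual ~IEncoderDecoder() = default;
  // skipCost is the proposed extra argument. The default is declared only
  // here, because default arguments on virtual functions are resolved from
  // the static type of the caller's pointer, not from the override.
  virtual Ptr<DecoderState> step(Ptr<DecoderState> state, bool skipCost = false) = 0;
};

struct ILogProbStep {
  virtual ~ILogProbStep() = default;
  virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) = 0;
};

// Hypothetical model that only marks that a decoding step happened.
struct FakeEncoderDecoder : IEncoderDecoder {
  Ptr<DecoderState> step(Ptr<DecoderState> state, bool /*skipCost*/) override {
    state->appliedOps.push_back("decode");
    return state;
  }
};

// Hypothetical cost step that only marks that it was applied.
struct FakeLogSoftmaxStep : ILogProbStep {
  Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
    state->appliedOps.push_back("logsoftmax");
    return state;
  }
};

class Stepwise : public IEncoderDecoder {
public:
  Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost)
      : encdec_(std::move(encdec)), cost_(std::move(cost)) {}

  Ptr<DecoderState> step(Ptr<DecoderState> state, bool skipCost) override {
    auto nextState = encdec_->step(state, skipCost);
    if (skipCost)
      return nextState;              // raw output, cost step bypassed
    return cost_->apply(nextState);  // existing behaviour
  }

private:
  Ptr<IEncoderDecoder> encdec_;
  Ptr<ILogProbStep> cost_;
};
```

With this shape the model is always constructed as a Stepwise, and callers that hold a Ptr<IEncoderDecoder> keep the current behaviour unless they pass skipCost=true explicitly.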

Call-site:

states[i] = scorers_[i]->step(graph, states[i], hypIndices, prevWords, batchIndices, (int)maxBeamSize);

virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
                              Ptr<ScorerState> state,
                              const std::vector<IndexType>& hypIndices,
                              const Words& words,
                              const std::vector<IndexType>& batchIndices,
                              int beamSize) override {
  graph->switchParams(getName());
  auto wrapperState = std::dynamic_pointer_cast<ScorerWrapperState>(state);
  auto newState = encdec_->step(graph, wrapperState->getState(), hypIndices, words, batchIndices, beamSize);
  return New<ScorerWrapperState>(newState);
}
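
For the flag to reach the model, the scorer wrapper between beam search and the model would have to forward it as well. A rough sketch of that forwarding, again with stand-in types: ToyStepwise, DecoderState::costApplied, beamSearchStep, and the reduced signatures are assumptions made for illustration, not Marian's API.

```cpp
#include <cassert>
#include <memory>
#include <utility>

template <class T> using Ptr = std::shared_ptr<T>;

// Hypothetical state that remembers whether the cost step ran.
struct DecoderState {
  bool costApplied = false;
};

struct IEncoderDecoder {
  virtual ~IEncoderDecoder() = default;
  virtual Ptr<DecoderState> step(Ptr<DecoderState> state, bool skipCost = false) = 0;
};

// Toy model: applies the "cost" unless asked to skip it.
struct ToyStepwise : IEncoderDecoder {
  Ptr<DecoderState> step(Ptr<DecoderState> state, bool skipCost) override {
    auto next = std::make_shared<DecoderState>(*state);
    next->costApplied = !skipCost;
    return next;
  }
};

struct ScorerState {
  virtual ~ScorerState() = default;
};

class ScorerWrapperState : public ScorerState {
public:
  explicit ScorerWrapperState(Ptr<DecoderState> state) : state_(std::move(state)) {}
  Ptr<DecoderState> getState() const { return state_; }

private:
  Ptr<DecoderState> state_;
};

class ScorerWrapper {
public:
  explicit ScorerWrapper(Ptr<IEncoderDecoder> encdec) : encdec_(std::move(encdec)) {}

  // Same unwrap/step/rewrap sequence as the excerpt above, plus the flag.
  Ptr<ScorerState> step(Ptr<ScorerState> state, bool skipCost = false) {
    auto wrapperState = std::dynamic_pointer_cast<ScorerWrapperState>(state);
    auto newState = encdec_->step(wrapperState->getState(), skipCost);
    return std::make_shared<ScorerWrapperState>(newState);
  }

private:
  Ptr<IEncoderDecoder> encdec_;
};

// Hypothetical stand-in for the beam-search call site: the per-call decision
// is made here and passed down through the scorer to the model.
inline Ptr<ScorerState> beamSearchStep(ScorerWrapper& scorer,
                                       Ptr<ScorerState> state,
                                       bool skipCost) {
  return scorer.step(std::move(state), skipCost);
}
```

Every layer between beam search and Stepwise needs the extra argument, so the change touches the scorer interface as well as IEncoderDecoder.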

jerinphilip added the enhancement label on Feb 22, 2022
jerinphilip changed the title from "Optional apply logprob computation at runtime" to "Optional apply logprob computation at call site instead of construction site" on Feb 22, 2022