Describe the bug
This is linked to #1500. Can we reuse/plug the MPSGraphGRU API directly through MLX?
`MetalPerformanceShadersGraph/MPSGraphRNNOps.h`
To Reproduce
Include code snippet
```objc
std::tuple<Tensor, Tensor> gru_cell_mps(
    const Tensor& input,        // x_t
    const Tensor& hidden_state, // h_{t-1}
    const Tensor& w_ih,         // Input weight (W_z, W_r, W_h concatenated)
    const Tensor& w_hh,         // Hidden weight (U_z, U_r, U_h concatenated)
    const Tensor& b_ih,         // Bias for input weights (b_z, b_r, b_h concatenated)
    const Tensor& b_hh          // Bias for hidden weights (secondary bias if reset_after = YES)
) {
  using namespace mps;
  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    MPSGraph* graph = [[MPSGraph alloc] init];

    // Define graph tensors
    MPSGraphTensor* inputTensor  = [graph constantWithTensor:input];
    MPSGraphTensor* hiddenTensor = [graph constantWithTensor:hidden_state];
    MPSGraphTensor* wIhTensor    = [graph constantWithTensor:w_ih];
    MPSGraphTensor* wHhTensor    = [graph constantWithTensor:w_hh];
    MPSGraphTensor* bIhTensor    = [graph constantWithTensor:b_ih];
    MPSGraphTensor* bHhTensor    = [graph constantWithTensor:b_hh];

    // Create GRU descriptor
    MPSGraphGRUDescriptor* descriptor = [MPSGraphGRUDescriptor descriptor];
    descriptor.reverse = NO;
    descriptor.bidirectional = NO;
    descriptor.training = NO;        // No training state needed
    descriptor.resetGateFirst = YES; // Assuming gate order is r, z, h
    descriptor.resetAfter = YES;     // Use the "reset-after" formulation

    // Apply GRU operation for a single step (T=1)
    NSArray<MPSGraphTensor*>* gruOutput = [graph GRUWithSourceTensor:inputTensor
                                                     recurrentWeight:wHhTensor
                                                         inputWeight:wIhTensor
                                                                bias:bIhTensor
                                                           initState:hiddenTensor
                                                       secondaryBias:bHhTensor
                                                          descriptor:descriptor
                                                                name:@"gru_cell"];

    // Extract output tensors
    MPSGraphTensor* nextHiddenState = gruOutput[0]; // h_t

    // Allocate output tensor
    Tensor outputHidden = at::empty_like(hidden_state);

    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds   = [NSMutableDictionary dictionary];
    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [NSMutableDictionary dictionary];

    Placeholder hiddenPlaceholder(nextHiddenState, outputHidden);
    [results setObject:hiddenPlaceholder.getMPSGraphTensorData()
                forKey:hiddenPlaceholder.getMPSGraphTensor()];

    // Run the graph
    runMPSGraph(stream, graph, feeds, results);

    return std::make_tuple(outputHidden, outputHidden);
  }
}
```
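For reference, the semantics that `resetAfter = YES` selects (and that the descriptor comments above allude to) can be sketched as a plain NumPy GRU step in the "reset-after" formulation with gate order r, z, n. This is just a hedged illustration of the math, not tied to MPSGraph or MLX internals; shapes and names are assumptions:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_cell_reset_after(x, h, w_ih, w_hh, b_ih, b_hh):
    """One GRU step, "reset-after" variant, gates stacked in order r, z, n.

    Assumed shapes (illustrative only):
      x: (input_size,), h: (hidden_size,)
      w_ih: (3*hidden_size, input_size), w_hh: (3*hidden_size, hidden_size)
      b_ih, b_hh: (3*hidden_size,)
    """
    H = h.shape[0]
    gi = w_ih @ x + b_ih  # input-side pre-activations for r, z, n
    gh = w_hh @ h + b_hh  # hidden-side pre-activations for r, z, n

    r = sigmoid(gi[:H] + gh[:H])          # reset gate
    z = sigmoid(gi[H:2 * H] + gh[H:2 * H])  # update gate
    # "reset-after": r scales the already-biased hidden term,
    # which is where the secondary bias b_hh matters.
    n = np.tanh(gi[2 * H:] + r * gh[2 * H:])
    return (1.0 - z) * n + z * h  # h_t
```

With all weights and biases zero, both gates are 0.5 and the candidate is 0, so the step returns `0.5 * h`, which is a quick sanity check on the gating arithmetic.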
Looking at the PyTorch MPS interface, they hook up LSTM but not GRU: https://github.com/pytorch/pytorch/blob/1886e33f6096175e6f0f77f4b44a39110d2656d6/aten/src/ATen/native/mps/operations/RnnOps.mm