[QUESTION] can we integrate direct MPSGraph features in mlx ? #1585

Open
thegodone opened this issue Nov 13, 2024 · 0 comments

Describe the bug
This is related to #1500.
Can we reuse the MPSGraph GRU API directly from MLX, or plug it in as a backend op?
The API is declared in MetalPerformanceShadersGraph/MPSGraphRNNOps.h.

To Reproduce

Sketch of a single-step GRU cell through MPSGraph, written in the style of PyTorch's ATen MPS backend:

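// Sketch only: relies on PyTorch's ATen MPS helpers (mpsGraphRankedPlaceHolder,
// Placeholder, runMPSGraph), not on MLX.
// Assumed shapes: input [1, B, I] (T = 1), hidden_state [B, H],
// w_ih [3H, I], w_hh [3H, H], b_ih and b_hh [3H].
std::tuple<Tensor, Tensor> gru_cell_mps(
    const Tensor& input,         // x_t
    const Tensor& hidden_state,  // h_{t-1}
    const Tensor& w_ih,          // Input weight (W_r, W_z, W_h concatenated)
    const Tensor& w_hh,          // Hidden weight (U_r, U_z, U_h concatenated)
    const Tensor& b_ih,          // Bias for input weights (b_r, b_z, b_h concatenated)
    const Tensor& b_hh           // Bias for hidden weights (secondary bias, since resetAfter = YES)
) {
    using namespace mps;

    MPSStream* stream = getCurrentMPSStream();
    @autoreleasepool {
        MPSGraph* graph = [[MPSGraph alloc] init];

        // Define graph placeholders; the tensor data is bound through `feeds` below
        MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(graph, input);
        MPSGraphTensor* hiddenTensor = mpsGraphRankedPlaceHolder(graph, hidden_state);
        MPSGraphTensor* wIhTensor = mpsGraphRankedPlaceHolder(graph, w_ih);
        MPSGraphTensor* wHhTensor = mpsGraphRankedPlaceHolder(graph, w_hh);
        MPSGraphTensor* bIhTensor = mpsGraphRankedPlaceHolder(graph, b_ih);
        MPSGraphTensor* bHhTensor = mpsGraphRankedPlaceHolder(graph, b_hh);

        // Create GRU descriptor
        MPSGraphGRUDescriptor* descriptor = [MPSGraphGRUDescriptor descriptor];
        descriptor.reverse = NO;
        descriptor.bidirectional = NO;
        descriptor.training = NO;  // No training state needed
        descriptor.resetGateFirst = YES;  // Assuming gate order is r, z, h
        descriptor.resetAfter = YES;     // Use the "reset-after" formulation

        // Apply GRU operation for a single step (T=1).
        // The selector that accepts secondaryBias also takes a mask argument
        // (per MPSGraphRNNOps.h); nil means no masking.
        NSArray<MPSGraphTensor*>* gruOutput = [graph GRUWithSourceTensor:inputTensor
                                                        recurrentWeight:wHhTensor
                                                            inputWeight:wIhTensor
                                                                   bias:bIhTensor
                                                              initState:hiddenTensor
                                                                   mask:nil
                                                          secondaryBias:bHhTensor
                                                             descriptor:descriptor
                                                                   name:@"gru_cell"];

        // With training = NO the op returns one tensor: the state, laid out as [T, B, H]
        MPSGraphTensor* nextHiddenState = gruOutput[0];  // h_t

        // Allocate output tensor with the [1, B, H] layout the graph produces
        Tensor outputHidden = at::empty({1, hidden_state.size(0), hidden_state.size(1)},
                                        hidden_state.options());

        // Bind every placeholder to its tensor data
        Placeholder inputPlaceholder(inputTensor, input);
        Placeholder hiddenInPlaceholder(hiddenTensor, hidden_state);
        Placeholder wIhPlaceholder(wIhTensor, w_ih);
        Placeholder wHhPlaceholder(wHhTensor, w_hh);
        Placeholder bIhPlaceholder(bIhTensor, b_ih);
        Placeholder bHhPlaceholder(bHhTensor, b_hh);
        Placeholder hiddenPlaceholder(nextHiddenState, outputHidden);

        NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
            inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
            hiddenInPlaceholder.getMPSGraphTensor() : hiddenInPlaceholder.getMPSGraphTensorData(),
            wIhPlaceholder.getMPSGraphTensor() : wIhPlaceholder.getMPSGraphTensorData(),
            wHhPlaceholder.getMPSGraphTensor() : wHhPlaceholder.getMPSGraphTensorData(),
            bIhPlaceholder.getMPSGraphTensor() : bIhPlaceholder.getMPSGraphTensorData(),
            bHhPlaceholder.getMPSGraphTensor() : bHhPlaceholder.getMPSGraphTensorData(),
        };
        NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
            hiddenPlaceholder.getMPSGraphTensor() : hiddenPlaceholder.getMPSGraphTensorData(),
        };

        // Run the graph
        runMPSGraph(stream, graph, feeds, results);

        // For a single step the output and the new hidden state are the same values
        Tensor h_t = outputHidden.squeeze(0);  // [B, H]
        return std::make_tuple(h_t, h_t);
    }
}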
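
To check any fused GRU kernel against a known result, it helps to have the single-step computation written out by hand. The following is a minimal sketch in plain C++ of a "reset-after" GRU cell with gate order r, z, n, which is the formulation the descriptor settings above are meant to select. The function name `gru_cell_reference` and the flat row-major weight layout are assumptions made for illustration; they are not part of MLX, PyTorch, or MPSGraph. The update rule used here is h_t = (1 - z) * n + z * h_{t-1}; MPSGraph's descriptor also exposes options such as `flipZ` that change the update convention, and that mapping is not verified here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference single-step GRU cell, "reset-after" formulation, gate order r, z, n.
// Hypothetical helper for validation only; not an MLX, PyTorch, or MPSGraph API.
// w_ih is row-major [3H, I], w_hh is row-major [3H, H], b_ih and b_hh are [3H].
std::vector<float> gru_cell_reference(const std::vector<float>& x,       // [I]
                                      const std::vector<float>& h_prev,  // [H]
                                      const std::vector<float>& w_ih,
                                      const std::vector<float>& w_hh,
                                      const std::vector<float>& b_ih,
                                      const std::vector<float>& b_hh) {
    const std::size_t I = x.size();
    const std::size_t H = h_prev.size();
    assert(w_ih.size() == 3 * H * I && w_hh.size() == 3 * H * H);
    assert(b_ih.size() == 3 * H && b_hh.size() == 3 * H);

    // gi = W_ih x + b_ih and gh = W_hh h_prev + b_hh, each of length 3H.
    std::vector<float> gi(b_ih), gh(b_hh);
    for (std::size_t row = 0; row < 3 * H; ++row) {
        for (std::size_t k = 0; k < I; ++k) gi[row] += w_ih[row * I + k] * x[k];
        for (std::size_t k = 0; k < H; ++k) gh[row] += w_hh[row * H + k] * h_prev[k];
    }

    auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };

    std::vector<float> h_next(H);
    for (std::size_t j = 0; j < H; ++j) {
        const float r = sigmoid(gi[j] + gh[j]);          // reset gate
        const float z = sigmoid(gi[H + j] + gh[H + j]);  // update gate
        // Reset-after: r scales the recurrent term after its matmul and bias.
        const float n = std::tanh(gi[2 * H + j] + r * gh[2 * H + j]);
        h_next[j] = (1.0f - z) * n + z * h_prev[j];
    }
    return h_next;
}
```

Calling it once per batch row with the same weights that are fed to the graph gives values to compare against the MPSGraph output, within floating-point tolerance.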

Looking at the PyTorch MPS backend, it wraps the MPSGraph LSTM op but not the GRU op: https://github.com/pytorch/pytorch/blob/1886e33f6096175e6f0f77f4b44a39110d2656d6/aten/src/ATen/native/mps/operations/RnnOps.mm

Desktop (please complete the following information):

  • OS Version: [e.g. MacOS 15.1]
  • Version [e.g. 0.20]