Skip to content

Commit

Permalink
Operations with call and return
Browse files Browse the repository at this point in the history
  • Loading branch information
jtristan committed Aug 20, 2024
1 parent 1dae639 commit e108188
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 111 deletions.
1 change: 1 addition & 0 deletions SHerLOC/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ import SHerLOC.Identifiers
import SHerLOC.Constants
import SHerLOC.Operations
import SHerLOC.Functions
import SHerLOC.Programs
1 change: 1 addition & 0 deletions SHerLOC/Identifiers.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ Authors: Jean-Baptiste Tristan

def FuncId := Nat
def ValueId := Nat
def UnusedId := Nat
238 changes: 132 additions & 106 deletions SHerLOC/Operations.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,115 +5,141 @@ Authors: Jean-Baptiste Tristan
-/
import SHerLOC.Constants
import SHerLOC.Identifiers
import SHerLOC.Types

/-!
# Operations
-/

inductive Operation where
| abs (operand : ValueId) (result : ValueId)
| add (lhs rhs : ValueId) (result : ValueId)
| after_all (inputs : ValueId) (result : ValueId)
| all_gather (operands : ValueId) (all_gather_dim replica_groups channel_id use_global_device_ids : Constant) (results : ValueId)
| all_reduce (operands : ValueId) (computation : FuncId) (replica_groups channel_id use_global_device_ids : Constant) (results : ValueId)
| all_to_all (operands : ValueId) (split_dimension concat_dimension split_count replica_groups channel_id : Constant) (results : ValueId)
| and (lhs rhs : ValueId) (result : ValueId)
| atan2 (lhs rhs : ValueId) (result : ValueId)
| batch_norm_grad (operand scale mean variance grad_output : ValueId) (epsilon feature_index : Constant) (grad_operand grad_scale grad_offset : ValueId)
| batch_norm_inference (operand scale offset mean variance : ValueId) (epsilon feature_index : Constant) (result : ValueId)
| batch_norm_training (operand scale offset : ValueId) (epsilon feature_index : Constant) (output batch_mean batch_var : ValueId)
| bitcast_convert (operand : ValueId) (result : ValueId)
| broadcast_in_dim (operand : ValueId) (broadcast_dimensions : Constant) (result : ValueId)
| case (index branches : ValueId) (results : ValueId)
| cbrt (operand : ValueId) (result : ValueId)
| ceil (operand : ValueId) (result : ValueId)
| cholesky (a : ValueId) (lower : Constant) (result : ValueId)
| clamp (min operand max : ValueId) (result : ValueId)
| collective_broadcast (operand : ValueId) (replica_groups channel_id : Constant) (result : ValueId)
| collective_permute (operand : ValueId) (source_target_pairs channel_id : Constant) (result : ValueId)
| compare (lhs rhs comparison_direction compare_type : ValueId) (result : ValueId)
| complex (lhs rhs : ValueId) (result : ValueId)
| composite (inputs composite_attributes : ValueId) (name decomposition version : Constant) (results : ValueId)
| concatenate (inputs : ValueId) (dimension : Constant) (result : ValueId)
| constant (value : Constant) (output : ValueId)
| convert (operand : ValueId) (result : ValueId)
| convolution (lhs rhs precision_config : ValueId) (window_strides padding lhs_dilation rhs_dilation window_reversal input_batch_dimension input_feature_dimension input_spatial_dimensions kernel_input_feature_dimension kernel_output_feature_dimension kernel_spatial_dimensions output_batch_dimension output_feature_dimension output_spatial_dimensions feature_group_count batch_group_count : Constant) (result : ValueId)
| cosine (operand : ValueId) (result : ValueId)
| count_leading_zeros (operand : ValueId) (result : ValueId)
| custom_call (inputs : ValueId) (call_target_name has_side_effect backend_config api_version called_computations : Constant) (results : ValueId)
| divide (lhs rhs : ValueId) (result : ValueId)
| dot_general (lhs rhs precision_config lhs_precision_type rhs_precision_type accumulation_type : ValueId) (lhs_batching_dimensions rhs_batching_dimensions lhs_contracting_dimensions rhs_contracting_dimensions lhs_component_count rhs_component_count num_primitive_operations allow_imprecise_accumulation : Constant) (result : ValueId)
| dynamic_broadcast_in_dim (operand output_dimensions : ValueId) (broadcast_dimensions known_expanding_dimensions known_non_expanding_dimensions : Constant) (result : ValueId)
| dynamic_conv (lhs rhs padding precision_config : ValueId) (window_strides lhs_dilation rhs_dilation window_reversal input_batch_dimension input_feature_dimension input_spatial_dimensions kernel_input_feature_dimension kernel_output_feature_dimension kernel_spatial_dimensions output_batch_dimension output_feature_dimension output_spatial_dimensions feature_group_count batch_group_count : Constant) (result : ValueId)
| dynamic_gather (operand start_indices slice_sizes : ValueId) (offset_dims collapsed_slice_dims start_index_map index_vector_dim indices_are_sorted : Constant) (result : ValueId)
| dynamic_iota (output_shape iota_dimension : ValueId) (result : ValueId)
| dynamic_pad (operand padding_value edge_padding_low edge_padding_high interior_padding : ValueId) (result : ValueId)
| dynamic_reshape (operand output_shape : ValueId) (result : ValueId)
| dynamic_slice (operand start_indices : ValueId) (slice_sizes : Constant) (result : ValueId)
| dynamic_update_slice (operand update start_indices : ValueId) (result : ValueId)
| exponential (operand : ValueId) (result : ValueId)
| exponential_minus_one (operand : ValueId) (result : ValueId)
| fft (operand fft_type : ValueId) (fft_length : Constant) (result : ValueId)
| floor (operand : ValueId) (result : ValueId)
| gather (operand start_indices : ValueId) (offset_dims collapsed_slice_dims operand_batching_dims start_indices_batching_dims start_index_map index_vector_dim slice_sizes indices_are_sorted : Constant) (result : ValueId)
| get_dimension_size (operand : ValueId) (dimension : Constant) (result : ValueId)
| get_tuple_element (operand : ValueId) (index : Constant) (result : ValueId)
| if (pred : ValueId) (true_branch false_branch : FuncId) (results : ValueId)
| imag (operand : ValueId) (result : ValueId)
| infeed (token : ValueId) (infeed_config : Constant) (results : ValueId)
| iota (iota_dimension : ValueId) (output : ValueId)
| is_finite (x : ValueId) (y : ValueId)
| log (operand : ValueId) (result : ValueId)
| log_plus_one (operand : ValueId) (result : ValueId)
| logistic (operand : ValueId) (result : ValueId)
| map (inputs : ValueId) (computation : FuncId) (dimensions : Constant) (result : ValueId)
| maximum (lhs rhs : ValueId) (result : ValueId)
| minimum (lhs rhs : ValueId) (result : ValueId)
| multiply (lhs rhs : ValueId) (result : ValueId)
| negate (operand : ValueId) (result : ValueId)
| not (operand : ValueId) (result : ValueId)
| optimization_barrier (operand : ValueId) (result : ValueId)
| or (lhs rhs : ValueId) (result : ValueId)
| outfeed (inputs token : ValueId) (outfeed_config : Constant) (result : ValueId)
| pad (operand padding_value : ValueId) (edge_padding_low edge_padding_high interior_padding : Constant) (result : ValueId)
| partition_id (result : ValueId)
| popcnt (operand : ValueId) (result : ValueId)
| power (lhs rhs : ValueId) (result : ValueId)
| real (operand : ValueId) (result : ValueId)
| recv (token channel_type : ValueId) (channel_id is_host_transfer : Constant) (results : ValueId)
| reduce (inputs init_values : ValueId) (body : FuncId) (dimensions : Constant) (results : ValueId)
| reduce_precision (operand : ValueId) (exponent_bits mantissa_bits : Constant) (output : ValueId)
| reduce_scatter (operand : ValueId) (computation : FuncId) (scatter_dimension replica_groups channel_id use_global_device_ids : Constant) (result : ValueId)
| reduce_window (inputs init_values : ValueId) (body : FuncId) (window_dimensions window_strides base_dilations window_dilations padding : Constant) (results : ValueId)
| remainder (lhs rhs : ValueId) (result : ValueId)
| replica_id (result : ValueId)
| reshape (operand : ValueId) (result : ValueId)
| reverse (operand : ValueId) (dimensions : Constant) (result : ValueId)
| rng (a b rng_distribution : ValueId) (shape : Constant) (result : ValueId)
| rng_bit_generator (rng_algorithm initial_state : ValueId) (output_state output : ValueId)
| round_nearest_afz (operand : ValueId) (result : ValueId)
| round_nearest_even (operand : ValueId) (result : ValueId)
| rsqrt (operand : ValueId) (result : ValueId)
| scatter (inputs scatter_indices updates : ValueId) (update_computation : FuncId) (update_window_dims inserted_window_dims input_batching_dims scatter_indices_batching_dims scatter_dims_to_operand_dims index_vector_dim indices_are_sorted unique_indices : Constant) (results : ValueId)
| select (pred on_true on_false : ValueId) (result : ValueId)
| select_and_scatter (operand source init_value : ValueId) (select scatter : FuncId) (window_dimensions window_strides padding : Constant) (result : ValueId)
| send (inputs token channel_type : ValueId) (channel_id is_host_transfer : Constant) (result : ValueId)
| shift_left (lhs rhs : ValueId) (result : ValueId)
| shift_right_arithmetic (lhs rhs : ValueId) (result : ValueId)
| shift_right_logical (lhs rhs : ValueId) (result : ValueId)
| sign (operand : ValueId) (result : ValueId)
| sine (operand : ValueId) (result : ValueId)
| slice (operand : ValueId) (start_indices limit_indices strides : Constant) (result : ValueId)
| sort (inputs : ValueId) (comparator : FuncId) (dimension is_stable : Constant) (results : ValueId)
| sqrt (operand : ValueId) (result : ValueId)
| subtract (lhs rhs : ValueId) (result : ValueId)
| tan (operand : ValueId) (result : ValueId)
| tanh (operand : ValueId) (result : ValueId)
| transpose (operand : ValueId) (permutation : Constant) (result : ValueId)
| triangular_solve (a b transpose_a : ValueId) (left_side lower unit_diagonal : Constant) (result : ValueId)
| tuple (val : ValueId) (result : ValueId)
| uniform_dequantize (operand : ValueId) (result : ValueId)
| uniform_quantize (operand : ValueId) (result : ValueId)
| while (operand : ValueId) (cond body : FuncId) (results : ValueId)
| xor (lhs rhs : ValueId) (result : ValueId)
inductive OpName where
| abs
| add
| after_all
| all_gather
| all_reduce
| all_to_all
| and
| atan2
| batch_norm_grad
| batch_norm_inference
| batch_norm_training
| bitcast_convert
| broadcast_in_dim
| case
| cbrt
| ceil
| cholesky
| clamp
| collective_broadcast
| collective_permute
| compare
| complex
| composite
| concatenate
| constant
| convert
| convolution
| cosine
| count_leading_zeros
| custom_call
| divide
| dot_general
| dynamic_broadcast_in_dim
| dynamic_conv
| dynamic_gather
| dynamic_iota
| dynamic_pad
| dynamic_reshape
| dynamic_slice
| dynamic_update_slice
| exponential
| exponential_minus_one
| fft
| floor
| gather
| get_dimension_size
| get_tuple_element
| if
| imag
| infeed
| iota
| is_finite
| log
| log_plus_one
| logistic
| map
| maximum
| minimum
| multiply
| negate
| not
| optimization_barrier
| or
| outfeed
| pad
| partition_id
| popcnt
| power
| real
| recv
| reduce
| reduce_precision
| reduce_scatter
| reduce_window
| remainder
| replica_id
| reshape
| reverse
| rng
| rng_bit_generator
| round_nearest_afz
| round_nearest_even
| rsqrt
| scatter
| select
| select_and_scatter
| send
| shift_left
| shift_right_arithmetic
| shift_right_logical
| sign
| sine
| slice
| sort
| sqrt
| subtract
| tan
| tanh
| transpose
| triangular_solve
| tuple
| uniform_dequantize
| uniform_quantize
| while
| xor

mutual

inductive InputFunc where
| mk
(id : UnusedId)
(funcInputs : List ValueId)
(body : List Operation)

inductive Operation where
| stable
(name : OpName)(inputValues : List ValueId)
(inputFunctions : List InputFunc)
(inputAttributes : List Constant)
(outputs : List ValueId)
(signature : FunctionType)
| return
(operands : ValueId)
(signature : ValueType)
| call
(callee : FuncId)
(arguments : ValueId)
(signature : FunctionType)

end
11 changes: 6 additions & 5 deletions SHerLOC/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ inductive QuantizedTensorElementType where
| quant : Signedness → IntegerSize → Int → Int → FloatSize → Int → List (Float × Int) → QuantizedTensorElementType

inductive ValueType where
| tensorType : List Int → TensorElementType → ValueType
| quantizedTensorType : List Int → QuantizedTensorElementType → ValueType
| tensorType (shape : List Int) (typ : TensorElementType)
| quantizedTensorType (shape : List Int) (typ : QuantizedTensorElementType)
| tokenType
| tupleType : List Valuetype → ValueType
| tupleType (elements : List Valuetype)

inductive TensorFloar32 where

inductive StringType where

inductive FunctionType where
| functionType : List ValueType → List ValueType → FunctionType
structure FunctionType where
domain : List ValueType
range : List ValueType

0 comments on commit e108188

Please sign in to comment.